diff --git a/.husky/pre-commit b/.husky/pre-commit index fa413d1..b464709 100644 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1 +1,2 @@ pnpm --no-install prettier:check +pnpm --no-install lint:sol diff --git a/contracts/test/utils/TestEncryptedErrors.sol b/contracts/test/utils/TestEncryptedErrors.sol new file mode 100644 index 0000000..95be14f --- /dev/null +++ b/contracts/test/utils/TestEncryptedErrors.sol @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: BSD-3-Clause-Clear +pragma solidity ^0.8.24; + +import "fhevm/lib/TFHE.sol"; +import { EncryptedErrors } from "../../utils/EncryptedErrors.sol"; +import { MockZamaFHEVMConfig } from "fhevm/config/ZamaFHEVMConfig.sol"; + +contract TestEncryptedErrors is MockZamaFHEVMConfig, EncryptedErrors { + constructor(uint8 totalNumberErrorCodes_) EncryptedErrors(totalNumberErrorCodes_) { + for (uint8 i; i <= totalNumberErrorCodes_; i++) { + /// @dev It is not possible to access the _errorCodeDefinitions since it is private. + TFHE.allow(TFHE.asEuint8(i), msg.sender); + } + } + + function errorChangeIf( + einput encryptedCondition, + einput encryptedErrorCode, + bytes calldata inputProof, + uint8 indexCode + ) external returns (euint8 newErrorCode) { + ebool condition = TFHE.asEbool(encryptedCondition, inputProof); + euint8 errorCode = TFHE.asEuint8(encryptedErrorCode, inputProof); + newErrorCode = _errorChangeIf(condition, indexCode, errorCode); + _errorSave(newErrorCode); + TFHE.allow(newErrorCode, msg.sender); + } + + function errorChangeIfNot( + einput encryptedCondition, + einput encryptedErrorCode, + bytes calldata inputProof, + uint8 indexCode + ) external returns (euint8 newErrorCode) { + ebool condition = TFHE.asEbool(encryptedCondition, inputProof); + euint8 errorCode = TFHE.asEuint8(encryptedErrorCode, inputProof); + newErrorCode = _errorChangeIfNot(condition, indexCode, errorCode); + _errorSave(newErrorCode); + TFHE.allow(newErrorCode, msg.sender); + } + + function errorDefineIf( + einput encryptedCondition, + bytes calldata inputProof, + uint8 indexCode + ) external returns (euint8 errorCode) { + ebool condition = TFHE.asEbool(encryptedCondition, inputProof); + errorCode = _errorDefineIf(condition, indexCode); + _errorSave(errorCode); + TFHE.allow(errorCode, msg.sender); + } + + function errorDefineIfNot( + einput encryptedCondition, + bytes calldata inputProof, + uint8 indexCode + ) external returns (euint8 errorCode) { + ebool condition = TFHE.asEbool(encryptedCondition, inputProof); + errorCode = _errorDefineIfNot(condition, indexCode); + _errorSave(errorCode); + TFHE.allow(errorCode, msg.sender); + } + + function errorGetCodeDefinition(uint8 indexCodeDefinition) external view returns (euint8 errorCode) { + errorCode = _errorGetCodeDefinition(indexCodeDefinition); + } + + function errorGetCodeEmitted(uint256 errorId) external view returns (euint8 errorCode) { + errorCode = _errorGetCodeEmitted(errorId); + } + + function errorGetCounter() external view returns (uint256 countErrors) { + countErrors = _errorGetCounter(); + } + + function errorGetNumCodesDefined() external view returns (uint8 totalNumberErrorCodes) { + totalNumberErrorCodes = _errorGetNumCodesDefined(); + } +} diff --git a/contracts/token/ERC20/EncryptedERC20.sol b/contracts/token/ERC20/EncryptedERC20.sol index dd0528e..60ffa60 100644 --- a/contracts/token/ERC20/EncryptedERC20.sol +++ b/contracts/token/ERC20/EncryptedERC20.sol @@ -2,7 +2,10 @@ pragma solidity ^0.8.24; import "fhevm/lib/TFHE.sol"; + +import { IERC20Errors } from "@openzeppelin/contracts/interfaces/draft-IERC6093.sol"; import { IEncryptedERC20 } from "./IEncryptedERC20.sol"; +import { TFHEErrors } from "../../utils/TFHEErrors.sol"; /** * @title EncryptedERC20 @@ -12,10 +15,9 @@ import { IEncryptedERC20 } from "./IEncryptedERC20.sol"; * and setting allowances, but uses encrypted data types. * The total supply is not encrypted. */ -abstract contract EncryptedERC20 is IEncryptedERC20 { +abstract contract EncryptedERC20 is IEncryptedERC20, IERC20Errors, TFHEErrors { /// @notice used as a placehoder in Approval and Transfer events to comply with the official EIP20 uint256 internal constant _PLACEHOLDER = type(uint256).max; - /// @notice Total supply. uint64 internal _totalSupply; @@ -31,11 +33,6 @@ abstract contract EncryptedERC20 is IEncryptedERC20 { /// @notice A mapping of the form mapping(account => mapping(spender => allowance)). mapping(address account => mapping(address spender => euint64 allowance)) internal _allowances; - /** - * @notice Error when the `sender` is not allowed to access a value. - */ - error TFHESenderNotAllowed(); - /** * @param name_ Name of the token. * @param symbol_ Symbol. @@ -151,6 +148,14 @@ abstract contract EncryptedERC20 is IEncryptedERC20 { } function _approve(address owner, address spender, euint64 amount) internal virtual { + if (owner == address(0)) { + revert ERC20InvalidApprover(owner); + } + + if (spender == address(0)) { + revert ERC20InvalidSpender(spender); + } + _allowances[owner][spender] = amount; TFHE.allowThis(amount); TFHE.allow(amount, owner); @@ -184,11 +189,11 @@ abstract contract EncryptedERC20 is IEncryptedERC20 { function _transferNoEvent(address from, address to, euint64 amount, ebool isTransferable) internal virtual { if (from == address(0)) { - revert SenderAddressNull(); + revert ERC20InvalidSender(from); } if (to == address(0)) { - revert ReceiverAddressNull(); + revert ERC20InvalidReceiver(to); } /// Add to the balance of `to` and subract from the balance of `from`. diff --git a/contracts/token/ERC20/IEncryptedERC20.sol b/contracts/token/ERC20/IEncryptedERC20.sol index c15c43a..463b246 100644 --- a/contracts/token/ERC20/IEncryptedERC20.sol +++ b/contracts/token/ERC20/IEncryptedERC20.sol @@ -22,16 +22,6 @@ interface IEncryptedERC20 { */ event Transfer(address indexed from, address indexed to, uint256 errorId); - /** - * @notice Returned when receiver is address(0). - */ - error ReceiverAddressNull(); - - /** - * @notice Returned when sender is address(0). - */ - error SenderAddressNull(); - /** * @notice Sets the `encryptedAmount` as the allowance of `spender` over the caller's tokens. */ diff --git a/contracts/token/ERC20/extensions/EncryptedERC20WithErrors.sol b/contracts/token/ERC20/extensions/EncryptedERC20WithErrors.sol index d52933f..6902178 100644 --- a/contracts/token/ERC20/extensions/EncryptedERC20WithErrors.sol +++ b/contracts/token/ERC20/extensions/EncryptedERC20WithErrors.sol @@ -30,8 +30,8 @@ abstract contract EncryptedERC20WithErrors is EncryptedERC20, EncryptedErrors { } /** - * @param name_ Name of the token. - * @param symbol_ Symbol. + * @param name_ Name of the token. + * @param symbol_ Symbol. */ constructor( string memory name_, @@ -43,7 +43,7 @@ abstract contract EncryptedERC20WithErrors is EncryptedERC20, EncryptedErrors { */ function transfer(address to, euint64 amount) public virtual override returns (bool) { _isSenderAllowedForAmount(amount); - /// Check whether the owner has enough tokens. + /// @dev Check whether the owner has enough tokens. ebool canTransfer = TFHE.le(amount, _balances[msg.sender]); euint8 errorCode = _errorDefineIfNot(canTransfer, uint8(ErrorCodes.UNSUFFICIENT_BALANCE)); _errorSave(errorCode); @@ -53,15 +53,6 @@ abstract contract EncryptedERC20WithErrors is EncryptedERC20, EncryptedErrors { return true; } - function getErrorCodeForTransferId(uint256 transferId) public view virtual returns (euint8) { - return _errorGetCodeEmitted(transferId); - } - - function _transfer(address from, address to, euint64 amount, ebool isTransferable) internal override { - _transferNoEvent(from, to, amount, isTransferable); - emit Transfer(from, to, _errorGetCounter() - 1); - } - /** * @notice See {IEncryptedERC20-transferFrom}. */ @@ -73,21 +64,37 @@ abstract contract EncryptedERC20WithErrors is EncryptedERC20, EncryptedErrors { return true; } + /** + * @notice Returns the error for a transfer id. + * @param transferId Transfer id. It can read from the `Transfer` event. + * @return errorCode Encrypted error code. + */ + function getErrorCodeForTransferId(uint256 transferId) public view virtual returns (euint8 errorCode) { + errorCode = _errorGetCodeEmitted(transferId); + } + + function _transfer(address from, address to, euint64 amount, ebool isTransferable) internal override { + _transferNoEvent(from, to, amount, isTransferable); + /// @dev It was incremented in _saveError. + emit Transfer(from, to, _errorGetCounter() - 1); + } + function _updateAllowance( address owner, address spender, euint64 amount ) internal virtual override returns (ebool isTransferable) { euint64 currentAllowance = _allowance(owner, spender); - /// Make sure sure the allowance suffices. + /// @dev It checks whether the allowance suffices. ebool allowedTransfer = TFHE.le(amount, currentAllowance); euint8 errorCode = _errorDefineIfNot(allowedTransfer, uint8(ErrorCodes.UNSUFFICIENT_APPROVAL)); - /// Make sure the owner has enough tokens. + /// @dev It checks that the owner has enough tokens. ebool canTransfer = TFHE.le(amount, _balances[owner]); ebool isNotTransferableButIsApproved = TFHE.and(TFHE.not(canTransfer), allowedTransfer); errorCode = _errorChangeIf( - isNotTransferableButIsApproved, // should indeed check that spender is approved to not leak information - // on balance of `from` to unauthorized spender via calling reencryptTransferError afterwards + isNotTransferableButIsApproved, + /// @dev Should indeed check that spender is approved to not leak information. + /// on balance of `from` to unauthorized spender via calling reencryptTransferError afterwards. uint8(ErrorCodes.UNSUFFICIENT_BALANCE), errorCode ); diff --git a/contracts/utils/EncryptedErrors.sol b/contracts/utils/EncryptedErrors.sol index 2220c88..4c7cd51 100644 --- a/contracts/utils/EncryptedErrors.sol +++ b/contracts/utils/EncryptedErrors.sol @@ -4,157 +4,195 @@ pragma solidity ^0.8.24; import "fhevm/lib/TFHE.sol"; /** - * @notice This abstract contract is used for error handling in the fhEVM. - * Error codes are encrypted in the constructor inside the `errorCodes` mapping. - * @dev `errorCodes[0]` should always refer to the `NO_ERROR` code, by default. + * @title EncryptedErrors. + * @notice This abstract contract is used for error handling in the fhEVM. + * Error codes are encrypted in the constructor inside the `_errorCodeDefinitions` mapping. + * @dev `_errorCodeDefinitions[0]` should always refer to the `NO_ERROR` code, by default. */ abstract contract EncryptedErrors { - /// @notice The total number of errors is equal to zero. - error TotalNumberErrorCodesEqualToZero(); - - /// @notice Error index is invalid. + /// @notice Returned if the error index is invalid. error ErrorIndexInvalid(); - /// @notice Error index is null. + /// @notice Returned if the error index is null. error ErrorIndexIsNull(); + /// @notice Returned if the total number of errors is equal to zero. + error TotalNumberErrorCodesEqualToZero(); + /// @notice Total number of error codes. - /// @dev Should hold the constant size of the _errorCodesDefinitions mapping + /// @dev Should hold the constant size of the `_errorCodeDefinitions` mapping. uint8 private immutable _TOTAL_NUMBER_ERROR_CODES; + /// @notice Used to keep track of number of emitted errors. + /// @dev Should hold the size of the _errorCodesEmitted mapping. + uint256 private _errorCounter; + /// @notice Mapping of trivially encrypted error codes definitions. - /// @dev In storage because solc does not support immutable mapping, neither immutable arrays, yet - mapping(uint8 errorCode => euint8 encryptedErrorCode) private _errorCodesDefinitions; + /// @dev In storage because solc does not support immutable mapping, neither immutable arrays, yet. + mapping(uint8 errorCode => euint8 encryptedErrorCode) private _errorCodeDefinitions; /// @notice Mapping of encrypted error codes emitted. mapping(uint256 errorIndex => euint8 encryptedErrorCode) private _errorCodesEmitted; - /// @notice Used to keep track of number of emitted errors - /// @dev Should hold the size of the _errorCodesEmitted mapping - uint256 private _errorCounter; - /** - * @notice Sets the non-null value for `_TOTAL_NUMBER_ERROR_CODES` corresponding to the total number of errors. - * @param totalNumberErrorCodes_ total number of different errors. - * @dev `totalNumberErrorCodes_` must be non-null (`_errorCodesDefinitions[0]` corresponds to the `NO_ERROR` code). + * @notice Sets the non-null value for `_TOTAL_NUMBER_ERROR_CODES` + * corresponding to the total number of errors. + * @param totalNumberErrorCodes_ Total number of different errors. + * @dev `totalNumberErrorCodes_` must be non-null + * (`_errorCodeDefinitions[0]` corresponds to the `NO_ERROR` code). */ constructor(uint8 totalNumberErrorCodes_) { if (totalNumberErrorCodes_ == 0) { revert TotalNumberErrorCodesEqualToZero(); } + for (uint8 i; i <= totalNumberErrorCodes_; i++) { euint8 errorCode = TFHE.asEuint8(i); - _errorCodesDefinitions[i] = errorCode; + _errorCodeDefinitions[i] = errorCode; TFHE.allowThis(errorCode); } + _TOTAL_NUMBER_ERROR_CODES = totalNumberErrorCodes_; } /** - * @notice Returns the trivially encrypted error code at index `indexCodeDefinition`. - * @param indexCodeDefinition the index of the requested error code definition. - * @return the trivially encrypted error code located at `indexCodeDefinition` of _errorCodesDefinitions mapping. + * @notice Computes an encrypted error code, result will be either a reencryption of + * `_errorCodeDefinitions[indexCode]` if `condition` is an encrypted `true` + * or of `errorCode` otherwise. + * @param condition Encrypted boolean used in the select operator. + * @param errorCode Selected error code if `condition` encrypts `true`. + * @return newErrorCode New reencrypted error code depending on `condition` value. + * @dev ` indexCode` must be below the total number of error codes. */ - function _errorGetCodeDefinition(uint8 indexCodeDefinition) internal view returns (euint8) { - if (indexCodeDefinition >= _TOTAL_NUMBER_ERROR_CODES) { + function _errorChangeIf( + ebool condition, + uint8 indexCode, + euint8 errorCode + ) internal virtual returns (euint8 newErrorCode) { + if (indexCode > _TOTAL_NUMBER_ERROR_CODES) { revert ErrorIndexInvalid(); } - return _errorCodesDefinitions[indexCodeDefinition]; + + newErrorCode = TFHE.select(condition, _errorCodeDefinitions[indexCode], errorCode); } /** - * @notice Returns the total counter of emitted of error codes. - * @return the number of errors emitted. + * @notice Does the opposite of `changeErrorIf`, i.e result will be either a reencryption of + * `_errorCodeDefinitions[indexCode]` if `condition` is an encrypted `false` + * or of `errorCode` otherwise. + * @param condition The encrypted boolean used in the `TFHE.select`. + * @param errorCode The selected error code if `condition` encrypts `false`. + * @return newErrorCode New error code depending on `condition` value. + * @dev `indexCode` must be below the total number of error codes. */ - function _errorGetCounter() internal view returns (uint256) { - return _errorCounter; + function _errorChangeIfNot( + ebool condition, + uint8 indexCode, + euint8 errorCode + ) internal virtual returns (euint8 newErrorCode) { + if (indexCode > _TOTAL_NUMBER_ERROR_CODES) { + revert ErrorIndexInvalid(); + } + + newErrorCode = TFHE.select(condition, errorCode, _errorCodeDefinitions[indexCode]); } /** - * @notice Returns the total number of the possible error codes defined. - * @return the total number of the different possible error codes. + * @notice Computes an encrypted error code, result will be either a reencryption of + * `_errorCodeDefinitions[indexCode]` if `condition` is an encrypted `true` + * or of `NO_ERROR` otherwise. + * @param condition Encrypted boolean used in the select operator. + * @param indexCode Index of the selected error code if `condition` encrypts `true`. + * @return errorCode Reencrypted error code depending on `condition` value. + * @dev `indexCode` must be non-null and below the total number of defined error codes. */ - function _errorGetNumCodesDefined() internal view returns (uint8) { - return _TOTAL_NUMBER_ERROR_CODES; + function _errorDefineIf(ebool condition, uint8 indexCode) internal virtual returns (euint8 errorCode) { + if (indexCode == 0) { + revert ErrorIndexIsNull(); + } + + if (indexCode > _TOTAL_NUMBER_ERROR_CODES) { + revert ErrorIndexInvalid(); + } + + errorCode = TFHE.select(condition, _errorCodeDefinitions[indexCode], _errorCodeDefinitions[0]); } /** - * @notice Returns the encrypted error code which was stored in the `_errorCodesEmitted` mapping at key `errorId`. - * @param errorId the requested key stored in the `_errorCodesEmitted` mapping. - * @return the encrypted error code located at the `errorId` key. - * @dev `errorId` must be a valid id, i.e below the error counter. + * @notice Does the opposite of `defineErrorIf`, i.e result will be either a reencryption of + * `_errorCodeDefinitions[indexCode]` if `condition` is an encrypted `false` or + * of `NO_ERROR` otherwise. + * @param condition Encrypted boolean used in the select operator. + * @param indexCode Index of the selected error code if `condition` encrypts `false`. + * @return errorCode Reencrypted error code depending on `condition` value. + * @dev `indexCode` must be non-null and below the total number of defined error codes. */ - function _errorGetCodeEmitted(uint256 errorId) internal view returns (euint8) { - if (errorId >= _errorCounter) revert ErrorIndexInvalid(); - return _errorCodesEmitted[errorId]; + function _errorDefineIfNot(ebool condition, uint8 indexCode) internal virtual returns (euint8 errorCode) { + if (indexCode == 0) { + revert ErrorIndexIsNull(); + } + + if (indexCode > _TOTAL_NUMBER_ERROR_CODES) { + revert ErrorIndexInvalid(); + } + + errorCode = TFHE.select(condition, _errorCodeDefinitions[0], _errorCodeDefinitions[indexCode]); } /** - * @notice Computes an encrypted error code, result will be either a reencryption of - * `_errorCodesDefinitions[indexCode]` if `condition` is an encrypted `true` or of `NO_ERROR` otherwise. - * @param condition the encrypted boolean used in the select operator. - * @param indexCode the index of the selected error code if `condition` encrypts `true`. - * @return the reencrypted error code depending on `condition` value. - * @dev `indexCode` must be non-null and below the total number of defined error codes. + * @notice Saves `errorCode` in storage, in the `_errorCodesEmitted` mapping. + * @param errorCode Encrypted error code to be saved in storage. + * @return errorId The `errorId` key in `_errorCodesEmitted` where `errorCode` is stored. */ - function _errorDefineIf(ebool condition, uint8 indexCode) internal returns (euint8) { - if (indexCode == 0) revert ErrorIndexIsNull(); - if (indexCode > _TOTAL_NUMBER_ERROR_CODES) revert ErrorIndexInvalid(); - euint8 errorCode = TFHE.select(condition, _errorCodesDefinitions[indexCode], _errorCodesDefinitions[0]); - return errorCode; + function _errorSave(euint8 errorCode) internal virtual returns (uint256 errorId) { + errorId = _errorCounter; + _errorCounter++; + _errorCodesEmitted[errorId] = errorCode; + + TFHE.allowThis(errorCode); } /** - * @notice Does the opposite of `defineErrorIf`, i.e result will be either a reencryption of - * `_errorCodesDefinitions[indexCode]` if `condition` is an encrypted `false` or of `NO_ERROR` otherwise. - * @param condition the encrypted boolean used in the select operator. - * @param indexCode the index of the selected error code if `condition` encrypts `false`. - * @return the reencrypted error code depending on `condition` value. - * @dev `indexCode` must be non-null and below the total number of defined error codes. + * @notice Returns the trivially encrypted error code at index `indexCodeDefinition`. + * @param indexCodeDefinition Index of the requested error code definition. + * @return errorCode Encrypted error code located at `indexCodeDefinition` in `_errorCodeDefinitions`. */ - function _errorDefineIfNot(ebool condition, uint8 indexCode) internal returns (euint8) { - if (indexCode == 0) revert ErrorIndexIsNull(); - if (indexCode > _TOTAL_NUMBER_ERROR_CODES) revert ErrorIndexInvalid(); - euint8 errorCode = TFHE.select(condition, _errorCodesDefinitions[0], _errorCodesDefinitions[indexCode]); - return errorCode; + function _errorGetCodeDefinition(uint8 indexCodeDefinition) internal view virtual returns (euint8 errorCode) { + if (indexCodeDefinition >= _TOTAL_NUMBER_ERROR_CODES) { + revert ErrorIndexInvalid(); + } + + errorCode = _errorCodeDefinitions[indexCodeDefinition]; } /** - * @notice Computes an encrypted error code, result will be either a reencryption of - * `_errorCodesDefinitions[indexCode]` if `condition` is an encrypted `true` or of `errorCode` otherwise. - * @param condition the encrypted boolean used in the select operator. - * @param errorCode the selected error code if `condition` encrypts `true`. - * @return the reencrypted error code depending on `condition` value. - * @dev `indexCode` must be below the total number of error codes. + * @notice Returns the encrypted error code which was stored in `_errorCodesEmitted` + * at key `errorId`. + * @param errorId Requested key stored in the `_errorCodesEmitted` mapping. + * @return errorCode Encrypted error code located at the `errorId` key. + * @dev `errorId` must be a valid id, i.e below the error counter. */ - function _errorChangeIf(ebool condition, uint8 indexCode, euint8 errorCode) internal returns (euint8) { - if (indexCode >= _TOTAL_NUMBER_ERROR_CODES) revert ErrorIndexInvalid(); - return TFHE.select(condition, _errorCodesDefinitions[indexCode], errorCode); + function _errorGetCodeEmitted(uint256 errorId) internal view virtual returns (euint8 errorCode) { + if (errorId >= _errorCounter) { + revert ErrorIndexInvalid(); + } + + errorCode = _errorCodesEmitted[errorId]; } /** - * @notice Does the opposite of `changeErrorIf`, i.e result will be either a reencryption of - * `_errorCodesDefinitions[indexCode]` if `condition` is an encrypted `false` or of `errorCode` otherwise. - * @param condition the encrypted boolean used in the cmux. - * @param errorCode the selected error code if `condition` encrypts `false`. - * @return the reencrypted error code depending on `condition` value. - * @dev `indexCode` must be below the total number of error codes. + * @notice Returns the total counter of emitted of error codes. + * @return countErrors Number of errors emitted. */ - function _errorChangeIfNot(ebool condition, uint8 indexCode, euint8 errorCode) internal returns (euint8) { - if (indexCode >= _TOTAL_NUMBER_ERROR_CODES) revert ErrorIndexInvalid(); - return TFHE.select(condition, errorCode, _errorCodesDefinitions[indexCode]); + function _errorGetCounter() internal view virtual returns (uint256 countErrors) { + countErrors = _errorCounter; } /** - * @notice Saves `errorCode` in storage, in the `_errorCodesEmitted` mapping, at the lowest unused key. - * @param errorCode the encrypted error code to be saved in storage. - * @return the `errorId` key in `_errorCodesEmitted` where `errorCode` is stored. + * @notice Returns the total number of the possible error codes defined. + * @return totalNumberErrorCodes Total number of the different possible error codes. */ - function _errorSave(euint8 errorCode) internal returns (uint256) { - uint256 errorId = _errorCounter; - _errorCounter++; - _errorCodesEmitted[errorId] = errorCode; - TFHE.allowThis(errorCode); - return errorId; + function _errorGetNumCodesDefined() internal view virtual returns (uint8 totalNumberErrorCodes) { + totalNumberErrorCodes = _TOTAL_NUMBER_ERROR_CODES; } } diff --git a/contracts/utils/TFHEErrors.sol b/contracts/utils/TFHEErrors.sol new file mode 100644 index 0000000..923ba04 --- /dev/null +++ b/contracts/utils/TFHEErrors.sol @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: BSD-3-Clause-Clear +pragma solidity ^0.8.24; + +interface TFHEErrors { + /** + * @notice Returned when the `sender` is not allowed to access a value. + */ + error TFHESenderNotAllowed(); +} diff --git a/test/encryptedERC20/EncryptedERC20.test.ts b/test/encryptedERC20/EncryptedERC20.test.ts index 98830a1..608d0ef 100644 --- a/test/encryptedERC20/EncryptedERC20.test.ts +++ b/test/encryptedERC20/EncryptedERC20.test.ts @@ -285,7 +285,7 @@ describe("EncryptedERC20", function () { [ "transfer(address,bytes32,bytes)" ](NULL_ADDRESS, encryptedTransferAmount.handles[0], encryptedTransferAmount.inputProof), - ).to.be.revertedWithCustomError(this.encryptedERC20, "ReceiverAddressNull"); + ).to.be.revertedWithCustomError(this.encryptedERC20, "ERC20InvalidReceiver"); }); it("sender who is not allowed cannot transfer using a handle from another account", async function () { diff --git a/test/encryptedERC20/EncryptedERC20WithErrors.test.ts b/test/encryptedERC20/EncryptedERC20WithErrors.test.ts index 22f381d..6ebb43c 100644 --- a/test/encryptedERC20/EncryptedERC20WithErrors.test.ts +++ b/test/encryptedERC20/EncryptedERC20WithErrors.test.ts @@ -194,7 +194,7 @@ describe("EncryptedERC20WithErrors", function () { await reencryptBalance(this.signers, this.instances, "bob", this.encryptedERC20, this.encryptedERC20Address), ).to.equal(0); // check that transfer did not happen, as expected - // Check that the error code matches if balance is not sufficient + // Check that the error code matches if approval is not sufficient expect( await checkErrorCode( this.signers, @@ -240,7 +240,7 @@ describe("EncryptedERC20WithErrors", function () { ), ).to.equal(0); - // Check that the error code matches if allowance is not sufficient + // Check that the error code matches if there is no error expect( await checkErrorCode( this.signers, @@ -332,6 +332,26 @@ describe("EncryptedERC20WithErrors", function () { } }); + it("spender cannot be null address", async function () { + const NULL_ADDRESS = "0x0000000000000000000000000000000000000000"; + const mintAmount = 100_000; + const transferAmount = 50_000; + const tx = await this.encryptedERC20.connect(this.signers.alice).mint(mintAmount); + await tx.wait(); + + const input = this.instances.alice.createEncryptedInput(this.encryptedERC20Address, this.signers.alice.address); + input.add64(transferAmount); + const encryptedTransferAmount = await input.encrypt(); + + await expect( + this.encryptedERC20 + .connect(this.signers.alice) + [ + "approve(address,bytes32,bytes)" + ](NULL_ADDRESS, encryptedTransferAmount.handles[0], encryptedTransferAmount.inputProof), + ).to.be.revertedWithCustomError(this.encryptedERC20, "ERC20InvalidSpender"); + }); + it("receiver cannot be null address", async function () { const NULL_ADDRESS = "0x0000000000000000000000000000000000000000"; const mintAmount = 100_000; @@ -349,7 +369,7 @@ describe("EncryptedERC20WithErrors", function () { [ "transfer(address,bytes32,bytes)" ](NULL_ADDRESS, encryptedTransferAmount.handles[0], encryptedTransferAmount.inputProof), - ).to.be.revertedWithCustomError(this.encryptedERC20, "ReceiverAddressNull"); + ).to.be.revertedWithCustomError(this.encryptedERC20, "ERC20InvalidReceiver"); }); it("sender who is not allowed cannot transfer using a handle from another account", async function () { diff --git a/test/utils/EncryptedErrors.fixture.ts b/test/utils/EncryptedErrors.fixture.ts new file mode 100644 index 0000000..ba8295e --- /dev/null +++ b/test/utils/EncryptedErrors.fixture.ts @@ -0,0 +1,11 @@ +import { ethers } from "hardhat"; + +import type { TestEncryptedErrors } from "../../types"; +import { Signers } from "../signers"; + +export async function deployEncryptedErrors(signers: Signers, numberErrors: number): Promise { + const contractFactory = await ethers.getContractFactory("TestEncryptedErrors"); + const contract = await contractFactory.connect(signers.alice).deploy(numberErrors); + await contract.waitForDeployment(); + return contract; +} diff --git a/test/utils/EncryptedErrors.test.ts b/test/utils/EncryptedErrors.test.ts new file mode 100644 index 0000000..4f327ef --- /dev/null +++ b/test/utils/EncryptedErrors.test.ts @@ -0,0 +1,283 @@ +import { expect } from "chai"; +import { ethers } from "hardhat"; + +import { createInstances } from "../instance"; +import { reencryptEuint8 } from "../reencrypt"; +import { getSigners, initSigners } from "../signers"; +import { deployEncryptedErrors } from "./EncryptedErrors.fixture"; + +describe("EncryptedErrors", function () { + const NO_ERROR_CODE = BigInt(0); + + before(async function () { + await initSigners(3); + this.signers = await getSigners(); + this.instances = await createInstances(this.signers); + }); + + beforeEach(async function () { + this.numberErrors = 3; + const contract = await deployEncryptedErrors(this.signers, this.numberErrors); + this.encryptedErrorsAddress = await contract.getAddress(); + this.encryptedErrors = contract; + }); + + it("post-deployment", async function () { + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("0")); + expect(await this.encryptedErrors.errorGetNumCodesDefined()).to.be.eq(BigInt("3")); + + for (let i = 0; i < 3; i++) { + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeDefinition(i); + expect( + await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress), + ).to.be.eq(i); + } + }); + + it("errorDefineIf --> true", async function () { + // True --> errorId=0 has errorCode=2 + const condition = true; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIf(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + targetErrorCode, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("errorDefineIf --> false", async function () { + // False --> errorId=1 has errorCode=0 + const condition = false; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIf(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + NO_ERROR_CODE, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("errorDefineIfNot --> true", async function () { + // True --> errorId=0 has errorCode=0 + const condition = true; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIfNot(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + NO_ERROR_CODE, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("errorDefineIf --> false", async function () { + // False --> errorId=1 has errorCode=2 + const condition = false; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIfNot(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + targetErrorCode, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("errorChangeIf --> true --> change error code", async function () { + // True --> change errorCode + const condition = true; + const errorCode = 1; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).add8(errorCode).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorChangeIf(encryptedData.handles[0], encryptedData.handles[1], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + targetErrorCode, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("errorChangeIf --> false --> no change for error code", async function () { + // False --> no change in errorCode + const condition = false; + const errorCode = 1; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).add8(errorCode).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorChangeIf(encryptedData.handles[0], encryptedData.handles[1], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + errorCode, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("errorChangeIfNot --> true --> no change for error code", async function () { + // True --> no change errorCode + const condition = true; + const errorCode = 1; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).add8(errorCode).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorChangeIfNot(encryptedData.handles[0], encryptedData.handles[1], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + errorCode, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("errorChangeIfNot --> false --> change error code", async function () { + // False --> change in errorCode + const condition = false; + const errorCode = 1; + const targetErrorCode = 2; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).add8(errorCode).encrypt(); + + await this.encryptedErrors + .connect(this.signers.alice) + .errorChangeIfNot(encryptedData.handles[0], encryptedData.handles[1], encryptedData.inputProof, targetErrorCode); + + const handle = await this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(0); + expect(await reencryptEuint8(this.signers, this.instances, "alice", handle, this.encryptedErrorsAddress)).to.be.eq( + targetErrorCode, + ); + expect(await this.encryptedErrors.errorGetCounter()).to.be.eq(BigInt("1")); + }); + + it("cannot deploy if totalNumberErrorCodes_ == 0", async function () { + const numberErrors = 0; + const contractFactory = await ethers.getContractFactory("TestEncryptedErrors"); + await expect(contractFactory.connect(this.signers.alice).deploy(numberErrors)).to.be.revertedWithCustomError( + this.encryptedErrors, + "TotalNumberErrorCodesEqualToZero", + ); + }); + + it("cannot define errors if indexCode is greater or equal than totalNumberErrorCodes", async function () { + const condition = true; + const targetErrorCode = (await this.encryptedErrors.errorGetNumCodesDefined()) + BigInt(1); + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).encrypt(); + + await expect( + this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIf(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexInvalid"); + + await expect( + this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIfNot(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexInvalid"); + }); + + it("cannot define errors if indexCode is 0 or equal", async function () { + const condition = true; + const targetErrorCode = 0; + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).encrypt(); + + await expect( + this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIf(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexIsNull"); + + await expect( + this.encryptedErrors + .connect(this.signers.alice) + .errorDefineIfNot(encryptedData.handles[0], encryptedData.inputProof, targetErrorCode), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexIsNull"); + }); + + it("cannot change errors if indexCode is greater or equal than totalNumberErrorCodes", async function () { + const condition = true; + const errorCode = 1; + const targetErrorCode = (await this.encryptedErrors.errorGetNumCodesDefined()) + BigInt(1); + + const input = this.instances.alice.createEncryptedInput(this.encryptedErrorsAddress, this.signers.alice.address); + const encryptedData = await input.addBool(condition).add8(errorCode).encrypt(); + + await expect( + this.encryptedErrors + .connect(this.signers.alice) + .errorChangeIf(encryptedData.handles[0], encryptedData.handles[1], encryptedData.inputProof, targetErrorCode), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexInvalid"); + + await expect( + this.encryptedErrors + .connect(this.signers.alice) + .errorChangeIfNot( + encryptedData.handles[0], + encryptedData.handles[1], + encryptedData.inputProof, + targetErrorCode, + ), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexInvalid"); + }); + + it("cannot call _errorGetCodeDefinition if indexCode is greater or equal than totalNumberErrorCodes", async function () { + const indexCodeDefinition = await this.encryptedErrors.errorGetNumCodesDefined(); + + await expect( + this.encryptedErrors.connect(this.signers.alice).errorGetCodeDefinition(indexCodeDefinition), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexInvalid"); + }); + + it("cannot call _errorGetCodeEmitted if errorId is greater than errorCounter", async function () { + const errorCounter = await this.encryptedErrors.errorGetCounter(); + + await expect( + this.encryptedErrors.connect(this.signers.alice).errorGetCodeEmitted(errorCounter), + ).to.be.revertedWithCustomError(this.encryptedErrors, "ErrorIndexInvalid"); + }); +});