diff --git a/contracts/contracts/BaseTFHEExecutor.sol b/contracts/contracts/BaseTFHEExecutor.sol new file mode 100644 index 00000000..417e7cf1 --- /dev/null +++ b/contracts/contracts/BaseTFHEExecutor.sol @@ -0,0 +1,605 @@ +// SPDX-License-Identifier: BSD-3-Clause-Clear +pragma solidity ^0.8.24; + +import "./ACL.sol"; +import "./FHEPayment.sol"; +import "../addresses/ACLAddress.sol"; +import "../addresses/FHEPaymentAddress.sol"; +import "../addresses/InputVerifierAddress.sol"; +import "@openzeppelin/contracts/utils/Strings.sol"; +import {Ownable2StepUpgradeable} from "@openzeppelin/contracts-upgradeable/access/Ownable2StepUpgradeable.sol"; + +interface IInputVerifier { + function verifyCiphertext( + BaseTFHEExecutor.ContextUserInputs memory context, + bytes32 inputHandle, + bytes memory inputProof + ) external returns (uint256); +} + +abstract contract BaseTFHEExecutor is Ownable2StepUpgradeable { + /// @notice Handle version + uint8 public constant HANDLE_VERSION = 0; + + /// @notice Name of the contract + string private constant CONTRACT_NAME = "TFHEExecutor"; + + /// @notice Version of the contract + uint256 private constant MAJOR_VERSION = 0; + uint256 private constant MINOR_VERSION = 1; + uint256 private constant PATCH_VERSION = 0; + + ACL private constant acl = ACL(aclAdd); + FHEPayment private constant fhePayment = FHEPayment(fhePaymentAdd); + IInputVerifier private constant inputVerifier = IInputVerifier(inputVerifierAdd); + + /// @custom:storage-location erc7201:fhevm.storage.TFHEExecutor + struct TFHEExecutorStorage { + uint256 counterRand; /// @notice counter used for computing handles of randomness operators + } + + struct ContextUserInputs { + address aclAddress; + address userAddress; + address contractAddress; + } + + // keccak256(abi.encode(uint256(keccak256("fhevm.storage.TFHEExecutor")) - 1)) & ~bytes32(uint256(0xff)) + bytes32 private constant TFHEExecutorStorageLocation = + 0xa436a06f0efce5ea38c956a21e24202a59b3b746d48a23fb52b4a5bc33fe3e00; + + function _getTFHEExecutorStorage() internal pure returns (TFHEExecutorStorage storage $) { + assembly { + $.slot := TFHEExecutorStorageLocation + } + } + + /// @notice Getter function for the ACL contract address + function getACLAddress() public view virtual returns (address) { + return address(acl); + } + + /// @notice Getter function for the FHEPayment contract address + function getFHEPaymentAddress() public view virtual returns (address) { + return address(fhePayment); + } + + /// @notice Getter function for the InputVerifier contract address + function getInputVerifierAddress() public view virtual returns (address) { + return address(inputVerifier); + } + + /// @notice Initializes the contract setting `initialOwner` as the initial owner + function initialize(address initialOwner) external initializer { + __Ownable_init(initialOwner); + } + + enum Operators { + fheAdd, + fheSub, + fheMul, + fheDiv, + fheRem, + fheBitAnd, + fheBitOr, + fheBitXor, + fheShl, + fheShr, + fheRotl, + fheRotr, + fheEq, + fheNe, + fheGe, + fheGt, + fheLe, + fheLt, + fheMin, + fheMax, + fheNeg, + fheNot, + verifyCiphertext, + cast, + trivialEncrypt, + fheIfThenElse, + fheRand, + fheRandBounded + } + + function isPowerOfTwo(uint256 x) internal pure virtual returns (bool) { + return (x > 0) && ((x & (x - 1)) == 0); + } + + /// @dev handle format for user inputs is: keccak256(keccak256(CiphertextFHEList)||index_handle)[0:29] || index_handle || handle_type || handle_version + /// @dev other handles format (fhe ops results) is: keccak256(keccak256(rawCiphertextFHEList)||index_handle)[0:30] || handle_type || handle_version + /// @dev the CiphertextFHEList actually contains: 1 byte (= N) for size of handles_list, N bytes for the handles_types : 1 per handle, then the original fhe160list raw ciphertext + function typeOf(uint256 handle) internal pure virtual returns (uint8) { + uint8 typeCt = uint8(handle >> 8); + return typeCt; + } + + function appendType(uint256 prehandle, uint8 handleType) internal pure virtual returns (uint256 result) { + result = prehandle & 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0000; + result = result | (uint256(handleType) << 8); // append type + result = result | HANDLE_VERSION; + } + + function requireType(uint256 handle, uint256 supportedTypes) internal pure virtual { + uint8 typeCt = typeOf(handle); + require((1 << typeCt) & supportedTypes > 0, "Unsupported type"); + } + + function unaryOp(Operators op, uint256 ct) internal virtual returns (uint256 result) { + require(acl.isAllowed(ct, msg.sender), "Sender doesn't own ct on op"); + result = uint256(keccak256(abi.encodePacked(op, ct, acl, block.chainid))); + uint8 typeCt = typeOf(ct); + result = appendType(result, typeCt); + acl.allowTransient(result, msg.sender); + } + + function binaryOp( + Operators op, + uint256 lhs, + uint256 rhs, + bytes1 scalar, + uint8 resultType + ) internal virtual returns (uint256 result) { + require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); + if (scalar == 0x00) { + require(acl.isAllowed(rhs, msg.sender), "Sender doesn't own rhs on op"); + uint8 typeRhs = typeOf(rhs); + uint8 typeLhs = typeOf(lhs); + require(typeLhs == typeRhs, "Incompatible types for lhs and rhs"); + } + result = uint256(keccak256(abi.encodePacked(op, lhs, rhs, scalar, acl, block.chainid))); + result = appendType(result, resultType); + acl.allowTransient(result, msg.sender); + } + + function ternaryOp( + Operators op, + uint256 lhs, + uint256 middle, + uint256 rhs + ) internal virtual returns (uint256 result) { + require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); + require(acl.isAllowed(middle, msg.sender), "Sender doesn't own middle on op"); + require(acl.isAllowed(rhs, msg.sender), "Sender doesn't own rhs on op"); + uint8 typeLhs = typeOf(lhs); + uint8 typeMiddle = typeOf(middle); + uint8 typeRhs = typeOf(rhs); + require(typeLhs == 0, "Unsupported type for lhs"); // lhs must be ebool + require(typeMiddle == typeRhs, "Incompatible types for middle and rhs"); + result = uint256(keccak256(abi.encodePacked(op, lhs, middle, rhs, acl, block.chainid))); + result = appendType(result, typeMiddle); + acl.allowTransient(result, msg.sender); + } + + function _fheAdd(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheAdd(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheAdd, lhs, rhs, scalar, lhsType); + } + + function _fheSub(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheSub(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheSub, lhs, rhs, scalar, lhsType); + } + + function _fheMul(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheMul(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheMul, lhs, rhs, scalar, lhsType); + } + + function _fheDiv(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + require(scalarByte & 0x01 == 0x01, "Only fheDiv by a scalar is supported"); + require(rhs != 0, "Could not divide by 0"); + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheDiv(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheDiv, lhs, rhs, scalar, lhsType); + } + + function _fheRem(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + require(scalarByte & 0x01 == 0x01, "Only fheRem by a scalar is supported"); + require(rhs != 0, "Could not divide by 0"); + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheRem(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheRem, lhs, rhs, scalar, lhsType); + } + + function _fheBitAnd(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheBitAnd(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheBitAnd, lhs, rhs, scalar, lhsType); + } + + function _fheBitOr(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheBitOr(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheBitOr, lhs, rhs, scalar, lhsType); + } + + function _fheBitXor(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheBitXor(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheBitXor, lhs, rhs, scalar, lhsType); + } + + function _fheShl(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheShl(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheShl, lhs, rhs, scalar, lhsType); + } + + function _fheShr(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheShr(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheShr, lhs, rhs, scalar, lhsType); + } + + function _fheRotl(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheRotl(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheRotl, lhs, rhs, scalar, lhsType); + } + + function _fheRotr(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheRotr(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheRotr, lhs, rhs, scalar, lhsType); + } + + function _fheEq(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + + (1 << 1) + + (1 << 2) + + (1 << 3) + + (1 << 4) + + (1 << 5) + + (1 << 6) + + (1 << 7) + + (1 << 8) + + (1 << 9) + + (1 << 10) + + (1 << 11); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + if (scalar == 0x01) { + require(lhsType <= 8, "Scalar fheEq for ebytesXXX types must use the overloaded fheEq"); + } + fhePayment.payForFheEq(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheEq, lhs, rhs, scalar, 0); + } + + function _fheEq(uint256 lhs, bytes memory rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + require(scalar == 0x01, "Overloaded fheEq is only for scalar ebytesXXX second operand"); + fhePayment.payForFheEq(msg.sender, lhsType, scalar); + require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); + uint256 lenBytesPT = rhs.length; + if (lhsType == 9) { + require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); + } else if (lhsType == 10) { + require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); + } else { + // @note: i.e lhsType == 11 thanks to the first pre-condition + require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); + } + result = uint256(keccak256(abi.encodePacked(Operators.fheEq, lhs, rhs, scalar, acl, block.chainid))); + result = appendType(result, 0); + acl.allowTransient(result, msg.sender); + } + + function _fheNe(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + + (1 << 1) + + (1 << 2) + + (1 << 3) + + (1 << 4) + + (1 << 5) + + (1 << 6) + + (1 << 7) + + (1 << 8) + + (1 << 9) + + (1 << 10) + + (1 << 11); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + if (scalar == 0x01) { + require(lhsType <= 8, "Scalar fheNe for ebytesXXX types must use the overloaded fheNe"); + } + fhePayment.payForFheNe(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheNe, lhs, rhs, scalar, 0); + } + + function _fheNe(uint256 lhs, bytes memory rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + require(scalar == 0x01, "Overloaded fheNe is only for scalar ebytesXXX second operand"); + fhePayment.payForFheNe(msg.sender, lhsType, scalar); + require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); + uint256 lenBytesPT = rhs.length; + if (lhsType == 9) { + require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); + } else if (lhsType == 10) { + require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); + } else { + // @note: i.e lhsType == 11 thanks to the first pre-condition + require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); + } + result = uint256(keccak256(abi.encodePacked(Operators.fheNe, lhs, rhs, scalar, acl, block.chainid))); + result = appendType(result, 0); + acl.allowTransient(result, msg.sender); + } + + function _fheGe(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheGe(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheGe, lhs, rhs, scalar, 0); + } + + function _fheGt(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheGt(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheGt, lhs, rhs, scalar, 0); + } + + function _fheLe(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheLe(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheLe, lhs, rhs, scalar, 0); + } + + function _fheLt(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheLt(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheLt, lhs, rhs, scalar, 0); + } + + function _fheMin(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheMin(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheMin, lhs, rhs, scalar, lhsType); + } + + function _fheMax(uint256 lhs, uint256 rhs, bytes1 scalarByte) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(lhs, supportedTypes); + uint8 lhsType = typeOf(lhs); + bytes1 scalar = scalarByte & 0x01; + fhePayment.payForFheMax(msg.sender, lhsType, scalar); + result = binaryOp(Operators.fheMax, lhs, rhs, scalar, lhsType); + } + + function _fheNeg(uint256 ct) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(ct, supportedTypes); + uint8 typeCt = typeOf(ct); + fhePayment.payForFheNeg(msg.sender, typeCt); + result = unaryOp(Operators.fheNeg, ct); + } + + function _fheNot(uint256 ct) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + requireType(ct, supportedTypes); + uint8 typeCt = typeOf(ct); + fhePayment.payForFheNot(msg.sender, typeCt); + result = unaryOp(Operators.fheNot, ct); + } + + function _verifyCiphertext( + bytes32 inputHandle, + address userAddress, + bytes memory inputProof, + bytes1 inputType + ) internal virtual returns (uint256 result) { + ContextUserInputs memory contextUserInputs = ContextUserInputs({ + aclAddress: address(acl), + userAddress: userAddress, + contractAddress: msg.sender + }); + uint8 typeCt = typeOf(uint256(inputHandle)); + require(uint8(inputType) == typeCt, "Wrong type"); + result = inputVerifier.verifyCiphertext(contextUserInputs, inputHandle, inputProof); + acl.allowTransient(result, msg.sender); + } + + function _cast(uint256 ct, bytes1 toType) internal virtual returns (uint256 result) { + require(acl.isAllowed(ct, msg.sender), "Sender doesn't own ct on cast"); + uint256 supportedTypesInput = (1 << 0) + + (1 << 1) + + (1 << 2) + + (1 << 3) + + (1 << 4) + + (1 << 5) + + (1 << 6) + + (1 << 8); + requireType(ct, supportedTypesInput); + uint256 supportedTypesOutput = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); // @note: unsupported casting to ebool (use fheNe instead) + require((1 << uint8(toType)) & supportedTypesOutput > 0, "Unsupported output type"); + uint8 typeCt = typeOf(ct); + require(bytes1(typeCt) != toType, "Cannot cast to same type"); + fhePayment.payForCast(msg.sender, typeCt); + result = uint256(keccak256(abi.encodePacked(Operators.cast, ct, toType, acl, block.chainid))); + result = appendType(result, uint8(toType)); + acl.allowTransient(result, msg.sender); + } + + function _trivialEncrypt(uint256 pt, bytes1 toType) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + + (1 << 1) + + (1 << 2) + + (1 << 3) + + (1 << 4) + + (1 << 5) + + (1 << 6) + + (1 << 7) + + (1 << 8); + uint8 toT = uint8(toType); + require((1 << toT) & supportedTypes > 0, "Unsupported type"); + fhePayment.payForTrivialEncrypt(msg.sender, toT); + result = uint256(keccak256(abi.encodePacked(Operators.trivialEncrypt, pt, toType, acl, block.chainid))); + result = appendType(result, toT); + acl.allowTransient(result, msg.sender); + } + + function _trivialEncrypt(bytes memory pt, bytes1 toType) internal virtual returns (uint256 result) { + // @note: overloaded function for ebytesXX types + uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); + uint8 toT = uint8(toType); + require((1 << toT) & supportedTypes > 0, "Unsupported type"); + fhePayment.payForTrivialEncrypt(msg.sender, toT); + uint256 lenBytesPT = pt.length; + if (toT == 9) { + require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); + } else if (toT == 10) { + require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); + } else { + // @note: i.e toT == 11 thanks to the pre-condition above + require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); + } + result = uint256(keccak256(abi.encodePacked(Operators.trivialEncrypt, pt, toType, acl, block.chainid))); + result = appendType(result, toT); + acl.allowTransient(result, msg.sender); + } + + function _fheIfThenElse( + uint256 control, + uint256 ifTrue, + uint256 ifFalse + ) internal virtual returns (uint256 result) { + uint256 supportedTypes = (1 << 0) + + (1 << 1) + + (1 << 2) + + (1 << 3) + + (1 << 4) + + (1 << 5) + + (1 << 6) + + (1 << 7) + + (1 << 8) + + (1 << 9) + + (1 << 10) + + (1 << 11); + requireType(ifTrue, supportedTypes); + uint8 typeCt = typeOf(ifTrue); + fhePayment.payForIfThenElse(msg.sender, typeCt); + result = ternaryOp(Operators.fheIfThenElse, control, ifTrue, ifFalse); + } + + function _fheRand(bytes1 randType) internal virtual returns (uint256 result) { + TFHEExecutorStorage storage $ = _getTFHEExecutorStorage(); + uint256 supportedTypes = (1 << 0) + + (1 << 1) + + (1 << 2) + + (1 << 3) + + (1 << 4) + + (1 << 5) + + (1 << 6) + + (1 << 8) + + (1 << 9) + + (1 << 10) + + (1 << 11); + uint8 randT = uint8(randType); + require((1 << randT) & supportedTypes > 0, "Unsupported erandom type"); + fhePayment.payForFheRand(msg.sender, randT); + bytes16 seed = bytes16( + keccak256(abi.encodePacked($.counterRand, acl, block.chainid, blockhash(block.number - 1), block.timestamp)) + ); + result = uint256(keccak256(abi.encodePacked(Operators.fheRand, randType, seed))); + result = appendType(result, randT); + acl.allowTransient(result, msg.sender); + $.counterRand++; + } + + function _fheRandBounded(uint256 upperBound, bytes1 randType) internal virtual returns (uint256 result) { + TFHEExecutorStorage storage $ = _getTFHEExecutorStorage(); + uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); + uint8 randT = uint8(randType); + require((1 << randT) & supportedTypes > 0, "Unsupported erandom type"); + require(isPowerOfTwo(upperBound), "UpperBound must be a power of 2"); + fhePayment.payForFheRandBounded(msg.sender, randT); + bytes16 seed = bytes16( + keccak256(abi.encodePacked($.counterRand, acl, block.chainid, blockhash(block.number - 1), block.timestamp)) + ); + result = uint256(keccak256(abi.encodePacked(Operators.fheRandBounded, upperBound, randType, seed))); + result = appendType(result, randT); + acl.allowTransient(result, msg.sender); + $.counterRand++; + } + + /// @notice Getter for the name and version of the contract + /// @return string representing the name and the version of the contract + function getVersion() external pure virtual returns (string memory) { + return + string( + abi.encodePacked( + CONTRACT_NAME, + " v", + Strings.toString(MAJOR_VERSION), + ".", + Strings.toString(MINOR_VERSION), + ".", + Strings.toString(PATCH_VERSION) + ) + ); + } +} diff --git a/contracts/contracts/TFHEExecutor.events.sol b/contracts/contracts/TFHEExecutor.events.sol index e930cc7c..cbdd63be 100644 --- a/contracts/contracts/TFHEExecutor.events.sol +++ b/contracts/contracts/TFHEExecutor.events.sol @@ -1,25 +1,10 @@ // SPDX-License-Identifier: BSD-3-Clause-Clear - pragma solidity ^0.8.24; -import "./ACL.sol"; -import "./FHEPayment.sol"; -import "../addresses/ACLAddress.sol"; -import "../addresses/FHEPaymentAddress.sol"; -import "../addresses/InputVerifierAddress.sol"; -import "@openzeppelin/contracts/utils/Strings.sol"; -import "@openzeppelin/contracts-upgradeable/utils/cryptography/EIP712Upgradeable.sol"; -import "@openzeppelin/contracts-upgradeable/access/Ownable2StepUpgradeable.sol"; +import {BaseTFHEExecutor} from "./BaseTFHEExecutor.sol"; +import {UUPSUpgradeable} from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; -interface IInputVerifier { - function verifyCiphertext( - TFHEExecutor.ContextUserInputs memory context, - bytes32 inputHandle, - bytes memory inputProof - ) external returns (uint256); -} - -contract TFHEExecutor is UUPSUpgradeable, Ownable2StepUpgradeable { +contract TFHEExecutor is BaseTFHEExecutor, UUPSUpgradeable { event FheAdd(uint256 lhs, uint256 rhs, bytes1 scalarByte, uint256 result); event FheSub(uint256 lhs, uint256 rhs, bytes1 scalarByte, uint256 result); event FheMul(uint256 lhs, uint256 rhs, bytes1 scalarByte, uint256 result); @@ -58,462 +43,128 @@ contract TFHEExecutor is UUPSUpgradeable, Ownable2StepUpgradeable { event FheRand(bytes1 randType, uint256 result); event FheRandBounded(uint256 upperBound, bytes1 randType, uint256 result); - /// @notice Handle version - uint8 public constant HANDLE_VERSION = 0; - - /// @notice Name of the contract - string private constant CONTRACT_NAME = "TFHEExecutor"; - - /// @notice Version of the contract - uint256 private constant MAJOR_VERSION = 0; - uint256 private constant MINOR_VERSION = 1; - uint256 private constant PATCH_VERSION = 0; - - ACL private constant acl = ACL(aclAdd); - FHEPayment private constant fhePayment = FHEPayment(fhePaymentAdd); - IInputVerifier private constant inputVerifier = IInputVerifier(inputVerifierAdd); - - /// @custom:storage-location erc7201:fhevm.storage.TFHEExecutor - struct TFHEExecutorStorage { - uint256 counterRand; /// @notice counter used for computing handles of randomness operators - } - - struct ContextUserInputs { - address aclAddress; - address userAddress; - address contractAddress; - } - - // keccak256(abi.encode(uint256(keccak256("fhevm.storage.TFHEExecutor")) - 1)) & ~bytes32(uint256(0xff)) - bytes32 private constant TFHEExecutorStorageLocation = - 0xa436a06f0efce5ea38c956a21e24202a59b3b746d48a23fb52b4a5bc33fe3e00; - - function _getTFHEExecutorStorage() internal pure returns (TFHEExecutorStorage storage $) { - assembly { - $.slot := TFHEExecutorStorageLocation - } - } - - function _authorizeUpgrade(address _newImplementation) internal virtual override onlyOwner {} - - /// @notice Getter function for the ACL contract address - function getACLAddress() public view virtual returns (address) { - return address(acl); - } - - /// @notice Getter function for the FHEPayment contract address - function getFHEPaymentAddress() public view virtual returns (address) { - return address(fhePayment); - } - - /// @notice Getter function for the InputVerifier contract address - function getInputVerifierAddress() public view virtual returns (address) { - return address(inputVerifier); - } - /// @custom:oz-upgrades-unsafe-allow constructor constructor() { _disableInitializers(); } - /// @notice Initializes the contract setting `initialOwner` as the initial owner - function initialize(address initialOwner) external initializer { - __Ownable_init(initialOwner); - } - - enum Operators { - fheAdd, - fheSub, - fheMul, - fheDiv, - fheRem, - fheBitAnd, - fheBitOr, - fheBitXor, - fheShl, - fheShr, - fheRotl, - fheRotr, - fheEq, - fheNe, - fheGe, - fheGt, - fheLe, - fheLt, - fheMin, - fheMax, - fheNeg, - fheNot, - verifyCiphertext, - cast, - trivialEncrypt, - fheIfThenElse, - fheRand, - fheRandBounded - } - - function isPowerOfTwo(uint256 x) internal pure virtual returns (bool) { - return (x > 0) && ((x & (x - 1)) == 0); - } - - /// @dev handle format for user inputs is: keccak256(keccak256(CiphertextFHEList)||index_handle)[0:29] || index_handle || handle_type || handle_version - /// @dev other handles format (fhe ops results) is: keccak256(keccak256(rawCiphertextFHEList)||index_handle)[0:30] || handle_type || handle_version - /// @dev the CiphertextFHEList actually contains: 1 byte (= N) for size of handles_list, N bytes for the handles_types : 1 per handle, then the original fhe160list raw ciphertext - function typeOf(uint256 handle) internal pure virtual returns (uint8) { - uint8 typeCt = uint8(handle >> 8); - return typeCt; - } - - function appendType(uint256 prehandle, uint8 handleType) internal pure virtual returns (uint256 result) { - result = prehandle & 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0000; - result = result | (uint256(handleType) << 8); // append type - result = result | HANDLE_VERSION; - } - - function requireType(uint256 handle, uint256 supportedTypes) internal pure virtual { - uint8 typeCt = typeOf(handle); - require((1 << typeCt) & supportedTypes > 0, "Unsupported type"); - } - - function unaryOp(Operators op, uint256 ct) internal virtual returns (uint256 result) { - require(acl.isAllowed(ct, msg.sender), "Sender doesn't own ct on op"); - result = uint256(keccak256(abi.encodePacked(op, ct, acl, block.chainid))); - uint8 typeCt = typeOf(ct); - result = appendType(result, typeCt); - acl.allowTransient(result, msg.sender); - } - - function binaryOp( - Operators op, - uint256 lhs, - uint256 rhs, - bytes1 scalar, - uint8 resultType - ) internal virtual returns (uint256 result) { - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - if (scalar == 0x00) { - require(acl.isAllowed(rhs, msg.sender), "Sender doesn't own rhs on op"); - uint8 typeRhs = typeOf(rhs); - uint8 typeLhs = typeOf(lhs); - require(typeLhs == typeRhs, "Incompatible types for lhs and rhs"); - } - result = uint256(keccak256(abi.encodePacked(op, lhs, rhs, scalar, acl, block.chainid))); - result = appendType(result, resultType); - acl.allowTransient(result, msg.sender); - } - - function ternaryOp( - Operators op, - uint256 lhs, - uint256 middle, - uint256 rhs - ) internal virtual returns (uint256 result) { - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - require(acl.isAllowed(middle, msg.sender), "Sender doesn't own middle on op"); - require(acl.isAllowed(rhs, msg.sender), "Sender doesn't own rhs on op"); - uint8 typeLhs = typeOf(lhs); - uint8 typeMiddle = typeOf(middle); - uint8 typeRhs = typeOf(rhs); - require(typeLhs == 0, "Unsupported type for lhs"); // lhs must be ebool - require(typeMiddle == typeRhs, "Incompatible types for middle and rhs"); - result = uint256(keccak256(abi.encodePacked(op, lhs, middle, rhs, acl, block.chainid))); - result = appendType(result, typeMiddle); - acl.allowTransient(result, msg.sender); - } - function fheAdd(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheAdd(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheAdd, lhs, rhs, scalar, lhsType); + result = _fheAdd(lhs, rhs, scalarByte); emit FheAdd(lhs, rhs, scalarByte, result); } function fheSub(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheSub(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheSub, lhs, rhs, scalar, lhsType); + result = _fheSub(lhs, rhs, scalarByte); emit FheSub(lhs, rhs, scalarByte, result); } function fheMul(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheMul(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheMul, lhs, rhs, scalar, lhsType); + result = _fheMul(lhs, rhs, scalarByte); emit FheMul(lhs, rhs, scalarByte, result); } function fheDiv(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - require(scalarByte & 0x01 == 0x01, "Only fheDiv by a scalar is supported"); - require(rhs != 0, "Could not divide by 0"); - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheDiv(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheDiv, lhs, rhs, scalar, lhsType); + result = _fheDiv(lhs, rhs, scalarByte); emit FheDiv(lhs, rhs, scalarByte, result); } function fheRem(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - require(scalarByte & 0x01 == 0x01, "Only fheRem by a scalar is supported"); - require(rhs != 0, "Could not divide by 0"); - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheRem(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheRem, lhs, rhs, scalar, lhsType); + result = _fheRem(lhs, rhs, scalarByte); emit FheRem(lhs, rhs, scalarByte, result); } function fheBitAnd(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheBitAnd(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheBitAnd, lhs, rhs, scalar, lhsType); + result = _fheBitAnd(lhs, rhs, scalarByte); emit FheBitAnd(lhs, rhs, scalarByte, result); } function fheBitOr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheBitOr(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheBitOr, lhs, rhs, scalar, lhsType); + result = _fheBitOr(lhs, rhs, scalarByte); emit FheBitOr(lhs, rhs, scalarByte, result); } function fheBitXor(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheBitXor(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheBitXor, lhs, rhs, scalar, lhsType); + result = _fheBitXor(lhs, rhs, scalarByte); emit FheBitXor(lhs, rhs, scalarByte, result); } function fheShl(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheShl(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheShl, lhs, rhs, scalar, lhsType); + result = _fheShl(lhs, rhs, scalarByte); emit FheShl(lhs, rhs, scalarByte, result); } function fheShr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheShr(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheShr, lhs, rhs, scalar, lhsType); + result = _fheShr(lhs, rhs, scalarByte); emit FheShr(lhs, rhs, scalarByte, result); } function fheRotl(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheRotl(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheRotl, lhs, rhs, scalar, lhsType); + result = _fheRotl(lhs, rhs, scalarByte); emit FheRotl(lhs, rhs, scalarByte, result); } - function fheRotr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheRotr(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheRotr, lhs, rhs, scalar, lhsType); + function fheRotr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { + result = _fheRotr(lhs, rhs, scalarByte); emit FheRotr(lhs, rhs, scalarByte, result); } function fheEq(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - if (scalar == 0x01) { - require(lhsType <= 8, "Scalar fheEq for ebytesXXX types must use the overloaded fheEq"); - } - fhePayment.payForFheEq(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheEq, lhs, rhs, scalar, 0); + result = _fheEq(lhs, rhs, scalarByte); emit FheEq(lhs, rhs, scalarByte, result); } function fheEq(uint256 lhs, bytes memory rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - require(scalar == 0x01, "Overloaded fheEq is only for scalar ebytesXXX second operand"); - fhePayment.payForFheEq(msg.sender, lhsType, scalar); - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - uint256 lenBytesPT = rhs.length; - if (lhsType == 9) { - require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); - } else if (lhsType == 10) { - require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); - } else { - // @note: i.e lhsType == 11 thanks to the first pre-condition - require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); - } - result = uint256(keccak256(abi.encodePacked(Operators.fheEq, lhs, rhs, scalar, acl, block.chainid))); - result = appendType(result, 0); - acl.allowTransient(result, msg.sender); + result = _fheEq(lhs, rhs, scalarByte); emit FheEqBytes(lhs, rhs, scalarByte, result); } function fheNe(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - if (scalar == 0x01) { - require(lhsType <= 8, "Scalar fheNe for ebytesXXX types must use the overloaded fheNe"); - } - fhePayment.payForFheNe(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheNe, lhs, rhs, scalar, 0); + result = _fheNe(lhs, rhs, scalarByte); emit FheNe(lhs, rhs, scalarByte, result); } function fheNe(uint256 lhs, bytes memory rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - require(scalar == 0x01, "Overloaded fheNe is only for scalar ebytesXXX second operand"); - fhePayment.payForFheNe(msg.sender, lhsType, scalar); - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - uint256 lenBytesPT = rhs.length; - if (lhsType == 9) { - require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); - } else if (lhsType == 10) { - require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); - } else { - // @note: i.e lhsType == 11 thanks to the first pre-condition - require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); - } - result = uint256(keccak256(abi.encodePacked(Operators.fheNe, lhs, rhs, scalar, acl, block.chainid))); - result = appendType(result, 0); - acl.allowTransient(result, msg.sender); + result = _fheNe(lhs, rhs, scalarByte); emit FheNeBytes(lhs, rhs, scalarByte, result); } function fheGe(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheGe(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheGe, lhs, rhs, scalar, 0); + result = _fheGe(lhs, rhs, scalarByte); emit FheGe(lhs, rhs, scalarByte, result); } function fheGt(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheGt(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheGt, lhs, rhs, scalar, 0); + result = _fheGt(lhs, rhs, scalarByte); emit FheGt(lhs, rhs, scalarByte, result); } function fheLe(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheLe(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheLe, lhs, rhs, scalar, 0); + result = _fheLe(lhs, rhs, scalarByte); emit FheLe(lhs, rhs, scalarByte, result); } function fheLt(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheLt(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheLt, lhs, rhs, scalar, 0); + result = _fheLt(lhs, rhs, scalarByte); emit FheLt(lhs, rhs, scalarByte, result); } function fheMin(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheMin(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheMin, lhs, rhs, scalar, lhsType); + result = _fheMin(lhs, rhs, scalarByte); emit FheMin(lhs, rhs, scalarByte, result); } function fheMax(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheMax(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheMax, lhs, rhs, scalar, lhsType); + result = _fheMax(lhs, rhs, scalarByte); emit FheMax(lhs, rhs, scalarByte, result); } function fheNeg(uint256 ct) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(ct, supportedTypes); - uint8 typeCt = typeOf(ct); - fhePayment.payForFheNeg(msg.sender, typeCt); - result = unaryOp(Operators.fheNeg, ct); + result = _fheNeg(ct); emit FheNeg(ct, result); } function fheNot(uint256 ct) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(ct, supportedTypes); - uint8 typeCt = typeOf(ct); - fhePayment.payForFheNot(msg.sender, typeCt); - result = unaryOp(Operators.fheNot, ct); + result = _fheNot(ct); emit FheNot(ct, result); } @@ -523,157 +174,39 @@ contract TFHEExecutor is UUPSUpgradeable, Ownable2StepUpgradeable { bytes memory inputProof, bytes1 inputType ) external virtual returns (uint256 result) { - ContextUserInputs memory contextUserInputs = ContextUserInputs({ - aclAddress: address(acl), - userAddress: userAddress, - contractAddress: msg.sender - }); - uint8 typeCt = typeOf(uint256(inputHandle)); - require(uint8(inputType) == typeCt, "Wrong type"); - result = inputVerifier.verifyCiphertext(contextUserInputs, inputHandle, inputProof); - acl.allowTransient(result, msg.sender); + result = _verifyCiphertext(inputHandle, userAddress, inputProof, inputType); emit VerifyCiphertext(inputHandle, userAddress, inputProof, inputType, result); } function cast(uint256 ct, bytes1 toType) external virtual returns (uint256 result) { - require(acl.isAllowed(ct, msg.sender), "Sender doesn't own ct on cast"); - uint256 supportedTypesInput = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 8); - requireType(ct, supportedTypesInput); - uint256 supportedTypesOutput = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); // @note: unsupported casting to ebool (use fheNe instead) - require((1 << uint8(toType)) & supportedTypesOutput > 0, "Unsupported output type"); - uint8 typeCt = typeOf(ct); - require(bytes1(typeCt) != toType, "Cannot cast to same type"); - fhePayment.payForCast(msg.sender, typeCt); - result = uint256(keccak256(abi.encodePacked(Operators.cast, ct, toType, acl, block.chainid))); - result = appendType(result, uint8(toType)); - acl.allowTransient(result, msg.sender); + result = _cast(ct, toType); emit Cast(ct, toType, result); } function trivialEncrypt(uint256 pt, bytes1 toType) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8); - uint8 toT = uint8(toType); - require((1 << toT) & supportedTypes > 0, "Unsupported type"); - fhePayment.payForTrivialEncrypt(msg.sender, toT); - result = uint256(keccak256(abi.encodePacked(Operators.trivialEncrypt, pt, toType, acl, block.chainid))); - result = appendType(result, toT); - acl.allowTransient(result, msg.sender); + result = _trivialEncrypt(pt, toType); emit TrivialEncrypt(pt, toType, result); } function trivialEncrypt(bytes memory pt, bytes1 toType) external virtual returns (uint256 result) { - // @note: overloaded function for ebytesXX types - uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); - uint8 toT = uint8(toType); - require((1 << toT) & supportedTypes > 0, "Unsupported type"); - fhePayment.payForTrivialEncrypt(msg.sender, toT); - uint256 lenBytesPT = pt.length; - if (toT == 9) { - require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); - } else if (toT == 10) { - require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); - } else { - // @note: i.e toT == 11 thanks to the pre-condition above - require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); - } - result = uint256(keccak256(abi.encodePacked(Operators.trivialEncrypt, pt, toType, acl, block.chainid))); - result = appendType(result, toT); - acl.allowTransient(result, msg.sender); + result = _trivialEncrypt(pt, toType); emit TrivialEncryptBytes(pt, toType, result); } function fheIfThenElse(uint256 control, uint256 ifTrue, uint256 ifFalse) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - requireType(ifTrue, supportedTypes); - uint8 typeCt = typeOf(ifTrue); - fhePayment.payForIfThenElse(msg.sender, typeCt); - result = ternaryOp(Operators.fheIfThenElse, control, ifTrue, ifFalse); + result = _fheIfThenElse(control, ifTrue, ifFalse); emit FheIfThenElse(control, ifTrue, ifFalse, result); } function fheRand(bytes1 randType) external virtual returns (uint256 result) { - TFHEExecutorStorage storage $ = _getTFHEExecutorStorage(); - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - uint8 randT = uint8(randType); - require((1 << randT) & supportedTypes > 0, "Unsupported erandom type"); - fhePayment.payForFheRand(msg.sender, randT); - bytes16 seed = bytes16( - keccak256(abi.encodePacked($.counterRand, acl, block.chainid, blockhash(block.number - 1), block.timestamp)) - ); - result = uint256(keccak256(abi.encodePacked(Operators.fheRand, randType, seed))); - result = appendType(result, randT); - acl.allowTransient(result, msg.sender); - $.counterRand++; + result = _fheRand(randType); emit FheRand(randType, result); } function fheRandBounded(uint256 upperBound, bytes1 randType) external virtual returns (uint256 result) { - TFHEExecutorStorage storage $ = _getTFHEExecutorStorage(); - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - uint8 randT = uint8(randType); - require((1 << randT) & supportedTypes > 0, "Unsupported erandom type"); - require(isPowerOfTwo(upperBound), "UpperBound must be a power of 2"); - fhePayment.payForFheRandBounded(msg.sender, randT); - bytes16 seed = bytes16( - keccak256(abi.encodePacked($.counterRand, acl, block.chainid, blockhash(block.number - 1), block.timestamp)) - ); - result = uint256(keccak256(abi.encodePacked(Operators.fheRandBounded, upperBound, randType, seed))); - result = appendType(result, randT); - acl.allowTransient(result, msg.sender); - $.counterRand++; + result = _fheRandBounded(upperBound, randType); emit FheRandBounded(upperBound, randType, result); } - /// @notice Getter for the name and version of the contract - /// @return string representing the name and the version of the contract - function getVersion() external pure virtual returns (string memory) { - return - string( - abi.encodePacked( - CONTRACT_NAME, - " v", - Strings.toString(MAJOR_VERSION), - ".", - Strings.toString(MINOR_VERSION), - ".", - Strings.toString(PATCH_VERSION) - ) - ); - } + function _authorizeUpgrade(address _newImplementation) internal virtual override onlyOwner {} } diff --git a/contracts/contracts/TFHEExecutor.sol b/contracts/contracts/TFHEExecutor.sol index 33eca574..4c2b18fe 100644 --- a/contracts/contracts/TFHEExecutor.sol +++ b/contracts/contracts/TFHEExecutor.sol @@ -1,458 +1,109 @@ // SPDX-License-Identifier: BSD-3-Clause-Clear - pragma solidity ^0.8.24; -import "./ACL.sol"; -import "./FHEPayment.sol"; -import "../addresses/ACLAddress.sol"; -import "../addresses/FHEPaymentAddress.sol"; -import "../addresses/InputVerifierAddress.sol"; -import "@openzeppelin/contracts/utils/Strings.sol"; -import "@openzeppelin/contracts-upgradeable/utils/cryptography/EIP712Upgradeable.sol"; -import "@openzeppelin/contracts-upgradeable/access/Ownable2StepUpgradeable.sol"; - -interface IInputVerifier { - function verifyCiphertext( - TFHEExecutor.ContextUserInputs memory context, - bytes32 inputHandle, - bytes memory inputProof - ) external returns (uint256); -} - -contract TFHEExecutor is UUPSUpgradeable, Ownable2StepUpgradeable { - /// @notice Handle version - uint8 public constant HANDLE_VERSION = 0; - - /// @notice Name of the contract - string private constant CONTRACT_NAME = "TFHEExecutor"; - - /// @notice Version of the contract - uint256 private constant MAJOR_VERSION = 0; - uint256 private constant MINOR_VERSION = 1; - uint256 private constant PATCH_VERSION = 0; - - ACL private constant acl = ACL(aclAdd); - FHEPayment private constant fhePayment = FHEPayment(fhePaymentAdd); - IInputVerifier private constant inputVerifier = IInputVerifier(inputVerifierAdd); - - /// @custom:storage-location erc7201:fhevm.storage.TFHEExecutor - struct TFHEExecutorStorage { - uint256 counterRand; /// @notice counter used for computing handles of randomness operators - } - - struct ContextUserInputs { - address aclAddress; - address userAddress; - address contractAddress; - } - - // keccak256(abi.encode(uint256(keccak256("fhevm.storage.TFHEExecutor")) - 1)) & ~bytes32(uint256(0xff)) - bytes32 private constant TFHEExecutorStorageLocation = - 0xa436a06f0efce5ea38c956a21e24202a59b3b746d48a23fb52b4a5bc33fe3e00; - - function _getTFHEExecutorStorage() internal pure returns (TFHEExecutorStorage storage $) { - assembly { - $.slot := TFHEExecutorStorageLocation - } - } - - function _authorizeUpgrade(address _newImplementation) internal virtual override onlyOwner {} - - /// @notice Getter function for the ACL contract address - function getACLAddress() public view virtual returns (address) { - return address(acl); - } - - /// @notice Getter function for the FHEPayment contract address - function getFHEPaymentAddress() public view virtual returns (address) { - return address(fhePayment); - } - - /// @notice Getter function for the InputVerifier contract address - function getInputVerifierAddress() public view virtual returns (address) { - return address(inputVerifier); - } +import {BaseTFHEExecutor} from "./BaseTFHEExecutor.sol"; +import {UUPSUpgradeable} from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; +contract TFHEExecutor is UUPSUpgradeable, BaseTFHEExecutor { /// @custom:oz-upgrades-unsafe-allow constructor constructor() { _disableInitializers(); } - /// @notice Initializes the contract setting `initialOwner` as the initial owner - function initialize(address initialOwner) external initializer { - __Ownable_init(initialOwner); - } - - enum Operators { - fheAdd, - fheSub, - fheMul, - fheDiv, - fheRem, - fheBitAnd, - fheBitOr, - fheBitXor, - fheShl, - fheShr, - fheRotl, - fheRotr, - fheEq, - fheNe, - fheGe, - fheGt, - fheLe, - fheLt, - fheMin, - fheMax, - fheNeg, - fheNot, - verifyCiphertext, - cast, - trivialEncrypt, - fheIfThenElse, - fheRand, - fheRandBounded - } - - function isPowerOfTwo(uint256 x) internal pure virtual returns (bool) { - return (x > 0) && ((x & (x - 1)) == 0); - } - - /// @dev handle format for user inputs is: keccak256(keccak256(CiphertextFHEList)||index_handle)[0:29] || index_handle || handle_type || handle_version - /// @dev other handles format (fhe ops results) is: keccak256(keccak256(rawCiphertextFHEList)||index_handle)[0:30] || handle_type || handle_version - /// @dev the CiphertextFHEList actually contains: 1 byte (= N) for size of handles_list, N bytes for the handles_types : 1 per handle, then the original fhe160list raw ciphertext - function typeOf(uint256 handle) internal pure virtual returns (uint8) { - uint8 typeCt = uint8(handle >> 8); - return typeCt; - } - - function appendType(uint256 prehandle, uint8 handleType) internal pure virtual returns (uint256 result) { - result = prehandle & 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0000; - result = result | (uint256(handleType) << 8); // append type - result = result | HANDLE_VERSION; - } - - function requireType(uint256 handle, uint256 supportedTypes) internal pure virtual { - uint8 typeCt = typeOf(handle); - require((1 << typeCt) & supportedTypes > 0, "Unsupported type"); - } - - function unaryOp(Operators op, uint256 ct) internal virtual returns (uint256 result) { - require(acl.isAllowed(ct, msg.sender), "Sender doesn't own ct on op"); - result = uint256(keccak256(abi.encodePacked(op, ct, acl, block.chainid))); - uint8 typeCt = typeOf(ct); - result = appendType(result, typeCt); - acl.allowTransient(result, msg.sender); - } - - function binaryOp( - Operators op, - uint256 lhs, - uint256 rhs, - bytes1 scalar, - uint8 resultType - ) internal virtual returns (uint256 result) { - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - if (scalar == 0x00) { - require(acl.isAllowed(rhs, msg.sender), "Sender doesn't own rhs on op"); - uint8 typeRhs = typeOf(rhs); - uint8 typeLhs = typeOf(lhs); - require(typeLhs == typeRhs, "Incompatible types for lhs and rhs"); - } - result = uint256(keccak256(abi.encodePacked(op, lhs, rhs, scalar, acl, block.chainid))); - result = appendType(result, resultType); - acl.allowTransient(result, msg.sender); - } - - function ternaryOp( - Operators op, - uint256 lhs, - uint256 middle, - uint256 rhs - ) internal virtual returns (uint256 result) { - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - require(acl.isAllowed(middle, msg.sender), "Sender doesn't own middle on op"); - require(acl.isAllowed(rhs, msg.sender), "Sender doesn't own rhs on op"); - uint8 typeLhs = typeOf(lhs); - uint8 typeMiddle = typeOf(middle); - uint8 typeRhs = typeOf(rhs); - require(typeLhs == 0, "Unsupported type for lhs"); // lhs must be ebool - require(typeMiddle == typeRhs, "Incompatible types for middle and rhs"); - result = uint256(keccak256(abi.encodePacked(op, lhs, middle, rhs, acl, block.chainid))); - result = appendType(result, typeMiddle); - acl.allowTransient(result, msg.sender); - } - function fheAdd(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheAdd(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheAdd, lhs, rhs, scalar, lhsType); + result = _fheAdd(lhs, rhs, scalarByte); } function fheSub(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheSub(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheSub, lhs, rhs, scalar, lhsType); + result = _fheSub(lhs, rhs, scalarByte); } function fheMul(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheMul(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheMul, lhs, rhs, scalar, lhsType); + result = _fheMul(lhs, rhs, scalarByte); } function fheDiv(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - require(scalarByte & 0x01 == 0x01, "Only fheDiv by a scalar is supported"); - require(rhs != 0, "Could not divide by 0"); - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheDiv(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheDiv, lhs, rhs, scalar, lhsType); + result = _fheDiv(lhs, rhs, scalarByte); } function fheRem(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - require(scalarByte & 0x01 == 0x01, "Only fheRem by a scalar is supported"); - require(rhs != 0, "Could not divide by 0"); - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheRem(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheRem, lhs, rhs, scalar, lhsType); + result = _fheRem(lhs, rhs, scalarByte); } function fheBitAnd(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheBitAnd(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheBitAnd, lhs, rhs, scalar, lhsType); + result = _fheBitAnd(lhs, rhs, scalarByte); } function fheBitOr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheBitOr(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheBitOr, lhs, rhs, scalar, lhsType); + result = _fheBitOr(lhs, rhs, scalarByte); } function fheBitXor(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheBitXor(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheBitXor, lhs, rhs, scalar, lhsType); + result = _fheBitXor(lhs, rhs, scalarByte); } function fheShl(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheShl(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheShl, lhs, rhs, scalar, lhsType); + result = _fheShl(lhs, rhs, scalarByte); } function fheShr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheShr(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheShr, lhs, rhs, scalar, lhsType); + result = _fheShr(lhs, rhs, scalarByte); } function fheRotl(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheRotl(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheRotl, lhs, rhs, scalar, lhsType); + result = _fheRotl(lhs, rhs, scalarByte); } - function fheRotr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheRotr(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheRotr, lhs, rhs, scalar, lhsType); + function fheRotr(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { + result = _fheRotr(lhs, rhs, scalarByte); } function fheEq(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - if (scalar == 0x01) { - require(lhsType <= 8, "Scalar fheEq for ebytesXXX types must use the overloaded fheEq"); - } - fhePayment.payForFheEq(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheEq, lhs, rhs, scalar, 0); + result = _fheEq(lhs, rhs, scalarByte); } function fheEq(uint256 lhs, bytes memory rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - require(scalar == 0x01, "Overloaded fheEq is only for scalar ebytesXXX second operand"); - fhePayment.payForFheEq(msg.sender, lhsType, scalar); - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - uint256 lenBytesPT = rhs.length; - if (lhsType == 9) { - require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); - } else if (lhsType == 10) { - require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); - } else { - // @note: i.e lhsType == 11 thanks to the first pre-condition - require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); - } - result = uint256(keccak256(abi.encodePacked(Operators.fheEq, lhs, rhs, scalar, acl, block.chainid))); - result = appendType(result, 0); - acl.allowTransient(result, msg.sender); + result = _fheEq(lhs, rhs, scalarByte); } function fheNe(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - if (scalar == 0x01) { - require(lhsType <= 8, "Scalar fheNe for ebytesXXX types must use the overloaded fheNe"); - } - fhePayment.payForFheNe(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheNe, lhs, rhs, scalar, 0); + result = _fheNe(lhs, rhs, scalarByte); } function fheNe(uint256 lhs, bytes memory rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - require(scalar == 0x01, "Overloaded fheNe is only for scalar ebytesXXX second operand"); - fhePayment.payForFheNe(msg.sender, lhsType, scalar); - require(acl.isAllowed(lhs, msg.sender), "Sender doesn't own lhs on op"); - uint256 lenBytesPT = rhs.length; - if (lhsType == 9) { - require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); - } else if (lhsType == 10) { - require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); - } else { - // @note: i.e lhsType == 11 thanks to the first pre-condition - require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); - } - result = uint256(keccak256(abi.encodePacked(Operators.fheNe, lhs, rhs, scalar, acl, block.chainid))); - result = appendType(result, 0); - acl.allowTransient(result, msg.sender); + result = _fheNe(lhs, rhs, scalarByte); } function fheGe(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheGe(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheGe, lhs, rhs, scalar, 0); + result = _fheGe(lhs, rhs, scalarByte); } function fheGt(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheGt(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheGt, lhs, rhs, scalar, 0); + result = _fheGt(lhs, rhs, scalarByte); } function fheLe(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheLe(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheLe, lhs, rhs, scalar, 0); + result = _fheLe(lhs, rhs, scalarByte); } function fheLt(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheLt(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheLt, lhs, rhs, scalar, 0); + result = _fheLt(lhs, rhs, scalarByte); } function fheMin(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheMin(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheMin, lhs, rhs, scalar, lhsType); + result = _fheMin(lhs, rhs, scalarByte); } function fheMax(uint256 lhs, uint256 rhs, bytes1 scalarByte) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(lhs, supportedTypes); - uint8 lhsType = typeOf(lhs); - bytes1 scalar = scalarByte & 0x01; - fhePayment.payForFheMax(msg.sender, lhsType, scalar); - result = binaryOp(Operators.fheMax, lhs, rhs, scalar, lhsType); + result = _fheMax(lhs, rhs, scalarByte); } function fheNeg(uint256 ct) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(ct, supportedTypes); - uint8 typeCt = typeOf(ct); - fhePayment.payForFheNeg(msg.sender, typeCt); - result = unaryOp(Operators.fheNeg, ct); + result = _fheNeg(ct); } function fheNot(uint256 ct) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - requireType(ct, supportedTypes); - uint8 typeCt = typeOf(ct); - fhePayment.payForFheNot(msg.sender, typeCt); - result = unaryOp(Operators.fheNot, ct); + result = _fheNot(ct); } function verifyCiphertext( @@ -461,150 +112,32 @@ contract TFHEExecutor is UUPSUpgradeable, Ownable2StepUpgradeable { bytes memory inputProof, bytes1 inputType ) external virtual returns (uint256 result) { - ContextUserInputs memory contextUserInputs = ContextUserInputs({ - aclAddress: address(acl), - userAddress: userAddress, - contractAddress: msg.sender - }); - uint8 typeCt = typeOf(uint256(inputHandle)); - require(uint8(inputType) == typeCt, "Wrong type"); - result = inputVerifier.verifyCiphertext(contextUserInputs, inputHandle, inputProof); - acl.allowTransient(result, msg.sender); + result = _verifyCiphertext(inputHandle, userAddress, inputProof, inputType); } function cast(uint256 ct, bytes1 toType) external virtual returns (uint256 result) { - require(acl.isAllowed(ct, msg.sender), "Sender doesn't own ct on cast"); - uint256 supportedTypesInput = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 8); - requireType(ct, supportedTypesInput); - uint256 supportedTypesOutput = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); // @note: unsupported casting to ebool (use fheNe instead) - require((1 << uint8(toType)) & supportedTypesOutput > 0, "Unsupported output type"); - uint8 typeCt = typeOf(ct); - require(bytes1(typeCt) != toType, "Cannot cast to same type"); - fhePayment.payForCast(msg.sender, typeCt); - result = uint256(keccak256(abi.encodePacked(Operators.cast, ct, toType, acl, block.chainid))); - result = appendType(result, uint8(toType)); - acl.allowTransient(result, msg.sender); + result = _cast(ct, toType); } function trivialEncrypt(uint256 pt, bytes1 toType) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8); - uint8 toT = uint8(toType); - require((1 << toT) & supportedTypes > 0, "Unsupported type"); - fhePayment.payForTrivialEncrypt(msg.sender, toT); - result = uint256(keccak256(abi.encodePacked(Operators.trivialEncrypt, pt, toType, acl, block.chainid))); - result = appendType(result, toT); - acl.allowTransient(result, msg.sender); + result = _trivialEncrypt(pt, toType); } function trivialEncrypt(bytes memory pt, bytes1 toType) external virtual returns (uint256 result) { - // @note: overloaded function for ebytesXX types - uint256 supportedTypes = (1 << 9) + (1 << 10) + (1 << 11); - uint8 toT = uint8(toType); - require((1 << toT) & supportedTypes > 0, "Unsupported type"); - fhePayment.payForTrivialEncrypt(msg.sender, toT); - uint256 lenBytesPT = pt.length; - if (toT == 9) { - require(lenBytesPT == 64, "Bytes array length of Bytes64 should be 64"); - } else if (toT == 10) { - require(lenBytesPT == 128, "Bytes array length of Bytes128 should be 128"); - } else { - // @note: i.e toT == 11 thanks to the pre-condition above - require(lenBytesPT == 256, "Bytes array length of Bytes256 should be 256"); - } - result = uint256(keccak256(abi.encodePacked(Operators.trivialEncrypt, pt, toType, acl, block.chainid))); - result = appendType(result, toT); - acl.allowTransient(result, msg.sender); + result = _trivialEncrypt(pt, toType); } function fheIfThenElse(uint256 control, uint256 ifTrue, uint256 ifFalse) external virtual returns (uint256 result) { - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 7) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - requireType(ifTrue, supportedTypes); - uint8 typeCt = typeOf(ifTrue); - fhePayment.payForIfThenElse(msg.sender, typeCt); - result = ternaryOp(Operators.fheIfThenElse, control, ifTrue, ifFalse); + result = _fheIfThenElse(control, ifTrue, ifFalse); } function fheRand(bytes1 randType) external virtual returns (uint256 result) { - TFHEExecutorStorage storage $ = _getTFHEExecutorStorage(); - uint256 supportedTypes = (1 << 0) + - (1 << 1) + - (1 << 2) + - (1 << 3) + - (1 << 4) + - (1 << 5) + - (1 << 6) + - (1 << 8) + - (1 << 9) + - (1 << 10) + - (1 << 11); - uint8 randT = uint8(randType); - require((1 << randT) & supportedTypes > 0, "Unsupported erandom type"); - fhePayment.payForFheRand(msg.sender, randT); - bytes16 seed = bytes16( - keccak256(abi.encodePacked($.counterRand, acl, block.chainid, blockhash(block.number - 1), block.timestamp)) - ); - result = uint256(keccak256(abi.encodePacked(Operators.fheRand, randType, seed))); - result = appendType(result, randT); - acl.allowTransient(result, msg.sender); - $.counterRand++; + result = _fheRand(randType); } function fheRandBounded(uint256 upperBound, bytes1 randType) external virtual returns (uint256 result) { - TFHEExecutorStorage storage $ = _getTFHEExecutorStorage(); - uint256 supportedTypes = (1 << 1) + (1 << 2) + (1 << 3) + (1 << 4) + (1 << 5) + (1 << 6) + (1 << 8); - uint8 randT = uint8(randType); - require((1 << randT) & supportedTypes > 0, "Unsupported erandom type"); - require(isPowerOfTwo(upperBound), "UpperBound must be a power of 2"); - fhePayment.payForFheRandBounded(msg.sender, randT); - bytes16 seed = bytes16( - keccak256(abi.encodePacked($.counterRand, acl, block.chainid, blockhash(block.number - 1), block.timestamp)) - ); - result = uint256(keccak256(abi.encodePacked(Operators.fheRandBounded, upperBound, randType, seed))); - result = appendType(result, randT); - acl.allowTransient(result, msg.sender); - $.counterRand++; - } - - /// @notice Getter for the name and version of the contract - /// @return string representing the name and the version of the contract - function getVersion() external pure virtual returns (string memory) { - return - string( - abi.encodePacked( - CONTRACT_NAME, - " v", - Strings.toString(MAJOR_VERSION), - ".", - Strings.toString(MINOR_VERSION), - ".", - Strings.toString(PATCH_VERSION) - ) - ); + result = _fheRandBounded(upperBound, randType); } + + function _authorizeUpgrade(address _newImplementation) internal virtual override onlyOwner {} }