diff --git a/ethereum/contracts/query/QueryDemo.sol b/ethereum/contracts/query/QueryDemo.sol index 12c0c3c9d5..b19c543767 100644 --- a/ethereum/contracts/query/QueryDemo.sol +++ b/ethereum/contracts/query/QueryDemo.sol @@ -101,10 +101,10 @@ contract QueryDemo is QueryResponse { EthCallQueryResponse memory eqr = parseEthCallQueryResponse(r.responses[i]); // Validate that update is not obsolete - validateBlockNum(eqr.blockNum, chainEntry.blockNum, block.number); + validateBlockNum(eqr.blockNum, chainEntry.blockNum); // Validate that update is not stale - validateBlockTime(eqr.blockTime, block.timestamp - 300, block.timestamp); + validateBlockTime(eqr.blockTime, block.timestamp - 300); if (eqr.result.length != 1) { revert UnexpectedResultMismatch(); diff --git a/ethereum/contracts/query/QueryResponse.sol b/ethereum/contracts/query/QueryResponse.sol index c5896401b3..4c93722a6b 100644 --- a/ethereum/contracts/query/QueryResponse.sol +++ b/ethereum/contracts/query/QueryResponse.sol @@ -77,8 +77,8 @@ error InvalidPayloadLength(uint256 received, uint256 expected); error InvalidContractAddress(); error InvalidFunctionSignature(); error InvalidChainId(); -error InvalidBlockNum(); -error InvalidBlockTime(); +error StaleBlockNum(); +error StaleBlockTime(); // @dev QueryResponse is a library that implements the parsing and verification of Cross Chain Query (CCQ) responses. abstract contract QueryResponse { @@ -357,34 +357,37 @@ abstract contract QueryResponse { checkLength(pcr.response, respIdx); } - /// @dev validateBlockTime validates that the parsed block time is in an acceptable range + /// @dev validateBlockTime validates that the parsed block time isn't stale /// @param _blockTime Wormhole block time in MICROseconds /// @param _minBlockTime Minium block time in seconds - /// @param _maxBlockTime Maximum block time in seconds - function validateBlockTime(uint64 _blockTime, uint256 _minBlockTime, uint256 _maxBlockTime) public pure { + function validateBlockTime(uint64 _blockTime, uint256 _minBlockTime) public pure { uint256 blockTimeInSeconds = _blockTime / 1_000_000; // Rounds down - if (blockTimeInSeconds < _minBlockTime || blockTimeInSeconds > _maxBlockTime) { - revert InvalidBlockTime(); + if (blockTimeInSeconds < _minBlockTime) { + revert StaleBlockTime(); } } - /// @dev validateBlockNum validates that the parsed blockNum is in an acceptable range - function validateBlockNum(uint64 _blockNum, uint256 _minBlockNum, uint256 _maxBlockNum) public pure { - if (_blockNum < _minBlockNum || _blockNum > _maxBlockNum) { - revert InvalidBlockNum(); + /// @dev validateBlockNum validates that the parsed blockNum isn't stale + function validateBlockNum(uint64 _blockNum, uint256 _minBlockNum) public pure { + if (_blockNum < _minBlockNum) { + revert StaleBlockNum(); } } /// @dev validateChainId validates that the parsed chainId is one of an array of chainIds we expect function validateChainId(uint16 chainId, uint16[] memory _validChainIds) public pure { bool validChainId = false; + + uint256 numChainIds = _validChainIds.length; - for (uint256 i = 0; i < _validChainIds.length; ++i) { - if (chainId == _validChainIds[i]) { + for (uint256 idx = 0; idx < numChainIds;) { + if (chainId == _validChainIds[idx]) { validChainId = true; break; } + + unchecked { ++idx; } } if (!validChainId) revert InvalidChainId(); @@ -392,8 +395,12 @@ abstract contract QueryResponse { /// @dev validateMutlipleEthCallData validates that each EthCallData in an array comes from a function signature and contract address we expect function validateMultipleEthCallData(EthCallData[] memory r, address[] memory _expectedContractAddresses, bytes4[] memory _expectedFunctionSignatures) public pure { - for (uint256 i = 0; i < r.length; ++i) { - validateEthCallData(r[i], _expectedContractAddresses, _expectedFunctionSignatures); + uint256 callDatasLength = r.length; + + for (uint256 idx = 0; idx < callDatasLength;) { + validateEthCallData(r[idx], _expectedContractAddresses, _expectedFunctionSignatures); + + unchecked { ++idx; } } } @@ -403,12 +410,16 @@ abstract contract QueryResponse { bool validContractAddress = _expectedContractAddresses.length == 0 ? true : false; bool validFunctionSignature = _expectedFunctionSignatures.length == 0 ? true : false; + uint256 contractAddressesLength = _expectedContractAddresses.length; + // Check that the contract address called in the request is expected - for (uint256 i = 0; i < _expectedContractAddresses.length; ++i) { - if (r.contractAddress == _expectedContractAddresses[i]) { + for (uint256 idx = 0; idx < contractAddressesLength;) { + if (r.contractAddress == _expectedContractAddresses[idx]) { validContractAddress = true; break; } + + unchecked { ++idx; } } // Early exit to save gas @@ -416,13 +427,17 @@ abstract contract QueryResponse { revert InvalidContractAddress(); } + uint256 functionSignaturesLength = _expectedFunctionSignatures.length; + // Check that the function signature called is expected - for (uint256 i = 0; i < _expectedFunctionSignatures.length; ++i) { + for (uint256 idx = 0; idx < functionSignaturesLength;) { (bytes4 funcSig,) = r.callData.asBytes4Unchecked(0); - if (funcSig == _expectedFunctionSignatures[i]) { + if (funcSig == _expectedFunctionSignatures[idx]) { validFunctionSignature = true; break; } + + unchecked { ++idx; } } if (!validFunctionSignature) { diff --git a/ethereum/forge-test/query/QueryResponse.t.sol b/ethereum/forge-test/query/QueryResponse.t.sol index 4ce8cde6cc..1915780133 100644 --- a/ethereum/forge-test/query/QueryResponse.t.sol +++ b/ethereum/forge-test/query/QueryResponse.t.sol @@ -481,34 +481,32 @@ contract TestQueryResponse is Test { queryResponse.verifyQueryResponseSignatures(resp, signatures); } - function testFuzz_validateBlockTime_success(uint256 _blockTime, uint256 _minBlockTime, uint256 _maxBlockTime) public view { + function testFuzz_validateBlockTime_success(uint256 _blockTime, uint256 _minBlockTime) public view { _blockTime = bound(_blockTime, 0, type(uint64).max/1_000_000); vm.assume(_blockTime >= _minBlockTime); - vm.assume(_blockTime <= _maxBlockTime); - queryResponse.validateBlockTime(uint64(_blockTime * 1_000_000), _minBlockTime, _maxBlockTime); + queryResponse.validateBlockTime(uint64(_blockTime * 1_000_000), _minBlockTime); } - function testFuzz_validateBlockTime_fail(uint256 _blockTime, uint256 _minBlockTime, uint256 _maxBlockTime) public { + function testFuzz_validateBlockTime_fail(uint256 _blockTime, uint256 _minBlockTime) public { _blockTime = bound(_blockTime, 0, type(uint64).max/1_000_000); - vm.assume(_blockTime < _minBlockTime || _blockTime > _maxBlockTime); + vm.assume(_blockTime < _minBlockTime); - vm.expectRevert(InvalidBlockTime.selector); - queryResponse.validateBlockTime(uint64(_blockTime * 1_000_000), _minBlockTime, _maxBlockTime); + vm.expectRevert(StaleBlockTime.selector); + queryResponse.validateBlockTime(uint64(_blockTime * 1_000_000), _minBlockTime); } - function testFuzz_validateBlockNum_success(uint64 _blockNum, uint256 _minBlockNum, uint256 _maxBlockNum) public view { + function testFuzz_validateBlockNum_success(uint64 _blockNum, uint256 _minBlockNum) public view { vm.assume(_blockNum >= _minBlockNum); - vm.assume(_blockNum <= _maxBlockNum); - queryResponse.validateBlockNum(_blockNum, _minBlockNum, _maxBlockNum); + queryResponse.validateBlockNum(_blockNum, _minBlockNum); } - function testFuzz_validateBlockNum_fail(uint64 _blockNum, uint256 _minBlockNum, uint256 _maxBlockNum) public { - vm.assume(_blockNum < _minBlockNum || _blockNum > _maxBlockNum); + function testFuzz_validateBlockNum_fail(uint64 _blockNum, uint256 _minBlockNum) public { + vm.assume(_blockNum < _minBlockNum); - vm.expectRevert(InvalidBlockNum.selector); - queryResponse.validateBlockNum(_blockNum, _minBlockNum, _maxBlockNum); + vm.expectRevert(StaleBlockNum.selector); + queryResponse.validateBlockNum(_blockNum, _minBlockNum); } function testFuzz_validateChainId_success(uint16 _validChainIndex, uint16[] memory _validChainIds) public view {