Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CCQ: High level validation helpers + slight refactor #3537

Merged
merged 4 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions ethereum/contracts/query/QueryDemo.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ import "./QueryResponse.sol";
error InvalidOwner();
// @dev for the onlyOwner modifier
error InvalidCaller();
error InvalidContractAddress();
error InvalidCalldata();
error InvalidWormholeAddress();
error InvalidForeignChainID();
error ObsoleteUpdate();
error StaleUpdate();
error UnexpectedCallData();
error UnexpectedResultLength();
error UnexpectedResultMismatch();

Expand All @@ -32,14 +31,13 @@ contract QueryDemo is QueryResponse {
}

address private immutable owner;
address private immutable wormhole;
uint16 private immutable myChainID;
mapping(uint16 => ChainEntry) private counters;
uint16[] private foreignChainIDs;

bytes4 GetMyCounter = bytes4(hex"916d5743");

constructor(address _owner, address _wormhole, uint16 _myChainID) {
constructor(address _owner, address _wormhole, uint16 _myChainID) QueryResponse(_wormhole) {
if (_owner == address(0)) {
revert InvalidOwner();
}
Expand All @@ -48,7 +46,7 @@ contract QueryDemo is QueryResponse {
if (_wormhole == address(0)) {
revert InvalidWormholeAddress();
}
wormhole = _wormhole;

myChainID = _myChainID;
counters[_myChainID] = ChainEntry(_myChainID, address(this), 0, 0, 0);
}
Expand Down Expand Up @@ -87,7 +85,7 @@ contract QueryDemo is QueryResponse {
// @notice Takes the cross chain query response for the other counters, stores the results for the other chains, and updates the counter for this chain.
function updateCounters(bytes memory response, IWormhole.Signature[] memory signatures) public {
uint256 adjustedBlockTime;
ParsedQueryResponse memory r = parseAndVerifyQueryResponse(address(wormhole), response, signatures);
ParsedQueryResponse memory r = parseAndVerifyQueryResponse(response, signatures);
uint numResponses = r.responses.length;
if (numResponses != foreignChainIDs.length) {
revert UnexpectedResultLength();
Expand All @@ -101,33 +99,24 @@ contract QueryDemo is QueryResponse {
}

EthCallQueryResponse memory eqr = parseEthCallQueryResponse(r.responses[i]);
if (eqr.blockNum <= chainEntry.blockNum) {
revert ObsoleteUpdate();
}

// wormhole time is in microseconds, timestamp is in seconds
adjustedBlockTime = eqr.blockTime / 1_000_000;
if (adjustedBlockTime <= block.timestamp - 300) {
revert StaleUpdate();
}
// Validate that update is not obsolete
validateBlockNum(eqr.blockNum, chainEntry.blockNum);

// Validate that update is not stale
validateBlockTime(eqr.blockTime, block.timestamp - 300);

if (eqr.result.length != 1) {
revert UnexpectedResultMismatch();
}

if (eqr.result[0].contractAddress != chainEntry.contractAddress) {
revert InvalidContractAddress();
}
// Validate addresses and function signatures
address[] memory validAddresses = new address[](1);
bytes4[] memory validFunctionSignatures = new bytes4[](1);
validAddresses[0] = chainEntry.contractAddress;
validFunctionSignatures[0] = GetMyCounter;

// TODO: Is there an easier way to verify that the call data is correct!
bytes memory callData = eqr.result[0].callData;
bytes4 result;
assembly {
result := mload(add(callData, 32))
}
if (result != GetMyCounter) {
revert UnexpectedCallData();
}
validateMultipleEthCallData(eqr.result, validAddresses, validFunctionSignatures);

require(eqr.result[0].result.length == 32, "result is not a uint256");

Expand Down
118 changes: 114 additions & 4 deletions ethereum/contracts/query/QueryResponse.sol
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct EthCallData {
}

// Custom errors
error EmptyWormholeAddress();
error InvalidResponseVersion();
error VersionMismatch();
error ZeroQueries();
Expand All @@ -73,18 +74,33 @@ error RequestTypeMismatch();
error UnsupportedQueryType();
error UnexpectedNumberOfResults();
error InvalidPayloadLength(uint256 received, uint256 expected);
error InvalidContractAddress();
error InvalidFunctionSignature();
error InvalidChainId();
error StaleBlockNum();
error StaleBlockTime();

// @dev QueryResponse is a library that implements the parsing and verification of Cross Chain Query (CCQ) responses.
abstract contract QueryResponse {
using BytesParsing for bytes;

IWormhole public immutable wormhole;
djb15 marked this conversation as resolved.
Show resolved Hide resolved

bytes public constant responsePrefix = bytes("query_response_0000000000000000000|");
uint8 public constant VERSION = 1;
uint8 public constant QT_ETH_CALL = 1;
uint8 public constant QT_ETH_CALL_BY_TIMESTAMP = 2;
uint8 public constant QT_ETH_CALL_WITH_FINALITY = 3;
uint8 public constant QT_MAX = 4; // Keep this last

constructor(address _wormhole) {
if (_wormhole == address(0)) {
revert EmptyWormholeAddress();
}

wormhole = IWormhole(_wormhole);
}

/// @dev getResponseHash computes the hash of the specified query response.
function getResponseHash(bytes memory response) public pure returns (bytes32) {
return keccak256(response);
Expand All @@ -96,8 +112,8 @@ abstract contract QueryResponse {
}

/// @dev parseAndVerifyQueryResponse verifies the query response and returns the parsed response.
function parseAndVerifyQueryResponse(address wormhole, bytes memory response, IWormhole.Signature[] memory signatures) public view returns (ParsedQueryResponse memory r) {
verifyQueryResponseSignatures(wormhole, response, signatures);
function parseAndVerifyQueryResponse(bytes memory response, IWormhole.Signature[] memory signatures) public view returns (ParsedQueryResponse memory r) {
verifyQueryResponseSignatures(response, signatures);

uint index = 0;

Expand Down Expand Up @@ -341,13 +357,107 @@ abstract contract QueryResponse {
checkLength(pcr.response, respIdx);
}

/// @dev validateBlockTime validates that the parsed block time isn't stale
/// @param _blockTime Wormhole block time in MICROseconds
/// @param _minBlockTime Minium block time in seconds
function validateBlockTime(uint64 _blockTime, uint256 _minBlockTime) public pure {
uint256 blockTimeInSeconds = _blockTime / 1_000_000; // Rounds down

if (blockTimeInSeconds < _minBlockTime) {
revert StaleBlockTime();
}
}

/// @dev validateBlockNum validates that the parsed blockNum isn't stale
function validateBlockNum(uint64 _blockNum, uint256 _minBlockNum) public pure {
if (_blockNum < _minBlockNum) {
revert StaleBlockNum();
}
}

/// @dev validateChainId validates that the parsed chainId is one of an array of chainIds we expect
djb15 marked this conversation as resolved.
Show resolved Hide resolved
function validateChainId(uint16 chainId, uint16[] memory _validChainIds) public pure {
bool validChainId = false;

uint256 numChainIds = _validChainIds.length;

for (uint256 idx = 0; idx < numChainIds;) {
if (chainId == _validChainIds[idx]) {
validChainId = true;
break;
}

unchecked { ++idx; }
}

if (!validChainId) revert InvalidChainId();
}

/// @dev validateMutlipleEthCallData validates that each EthCallData in an array comes from a function signature and contract address we expect
function validateMultipleEthCallData(EthCallData[] memory r, address[] memory _expectedContractAddresses, bytes4[] memory _expectedFunctionSignatures) public pure {
uint256 callDatasLength = r.length;

for (uint256 idx = 0; idx < callDatasLength;) {
validateEthCallData(r[idx], _expectedContractAddresses, _expectedFunctionSignatures);

unchecked { ++idx; }
}
}

/// @dev validateEthCallData validates that EthCallData comes from a function signature and contract address we expect
/// @dev An empty array means we accept all addresses/function signatures
/// @dev Example 1: To accept signatures 0xaaaaaaaa and 0xbbbbbbbb from `address(abcd)` you'd pass in [0xaaaaaaaa, 0xbbbbbbbb], [address(abcd)]
/// @dev Example 2: To accept any function signatures from `address(abcd)` or `address(efab)` you'd pass in [], [address(abcd), address(efab)]
/// @dev Example 3: To accept function signature 0xaaaaaaaa from any address you'd pass in [0xaaaaaaaa], []
/// @dev WARNING Example 4: If you want to accept signature 0xaaaaaaaa from `address(abcd)` and signature 0xbbbbbbbb from `address(efab)` the following input would be incorrect:
/// @dev [0xaaaaaaaa, 0xbbbbbbbb], [address(abcd), address(efab)]
/// @dev This would accept both 0xaaaaaaaa and 0xbbbbbbbb from `address(abcd)` AND `address(efab)`. Instead you should make 2 calls to this method
/// @dev using the pattern in Example 1. [0xaaaaaaaa], [address(abcd)] OR [0xbbbbbbbb], [address(efab)]
function validateEthCallData(EthCallData memory r, address[] memory _expectedContractAddresses, bytes4[] memory _expectedFunctionSignatures) public pure {
bool validContractAddress = _expectedContractAddresses.length == 0 ? true : false;
bool validFunctionSignature = _expectedFunctionSignatures.length == 0 ? true : false;

uint256 contractAddressesLength = _expectedContractAddresses.length;

// Check that the contract address called in the request is expected
for (uint256 idx = 0; idx < contractAddressesLength;) {
if (r.contractAddress == _expectedContractAddresses[idx]) {
validContractAddress = true;
break;
}

unchecked { ++idx; }
}

// Early exit to save gas
if (!validContractAddress) {
revert InvalidContractAddress();
}

uint256 functionSignaturesLength = _expectedFunctionSignatures.length;

// Check that the function signature called is expected
for (uint256 idx = 0; idx < functionSignaturesLength;) {
(bytes4 funcSig,) = r.callData.asBytes4Unchecked(0);
if (funcSig == _expectedFunctionSignatures[idx]) {
validFunctionSignature = true;
break;
}

unchecked { ++idx; }
}

if (!validFunctionSignature) {
revert InvalidFunctionSignature();
}
}

/**
* @dev verifyQueryResponseSignatures verifies the signatures on a query response. It calls into the Wormhole contract.
* IWormhole.Signature expects the last byte to be bumped by 27
* see https://github.com/wormhole-foundation/wormhole/blob/637b1ee657de7de05f783cbb2078dd7d8bfda4d0/ethereum/contracts/Messages.sol#L174
*/
function verifyQueryResponseSignatures(address _wormhole, bytes memory response, IWormhole.Signature[] memory signatures) public view {
IWormhole wormhole = IWormhole(_wormhole);
function verifyQueryResponseSignatures(bytes memory response, IWormhole.Signature[] memory signatures) public view {
// It might be worth adding a verifyCurrentQuorum call on the core bridge so that there is only 1 cross call instead of 4.
uint32 gsi = wormhole.getCurrentGuardianSetIndex();
IWormhole.GuardianSet memory guardianSet = wormhole.getGuardianSet(gsi);
Expand Down
Loading
Loading