Skip to content

Commit

Permalink
feat: skip signature length encoding on final sig
Browse files Browse the repository at this point in the history
  • Loading branch information
jaypaik committed Oct 8, 2024
1 parent 7cc3457 commit 43d9d45
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 197 deletions.
50 changes: 16 additions & 34 deletions src/libraries/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,26 @@ library SparseCalldataSegmentLib {

/// @notice Splits out a segment of calldata, sparsely-packed.
/// The expected format is:
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
/// [uint8(index0), uint32(len(segment0)), segment0, uint8(index1), uint32(len(segment1)), segment1,
/// ... uint8(indexN), uint32(len(segmentN)), segmentN]
/// @param source The calldata to extract the segment from.
/// @return segment The extracted segment. Using the above example, this would be segment0.
/// @return remainder The remaining calldata. Using the above example,
/// this would start at uint32(len(segment1)) and continue to the end at segmentN.
/// this would start at uint8(index1) and continue to the end at segmentN.
function getNextSegment(bytes calldata source)
internal
pure
returns (bytes calldata segment, bytes calldata remainder)
{
// The first 4 bytes hold the length of the segment, excluding the index.
uint32 length = uint32(bytes4(source[:4]));
// The first byte of the segment is the index.
// The next 4 bytes hold the length of the segment, excluding the index.
uint32 length = uint32(bytes4(source[1:5]));

// The offset of the remainder of the calldata.
uint256 remainderOffset = 4 + length;
uint256 remainderOffset = 5 + length;

// The segment is the next `length` + 1 bytes, to account for the index.
// By convention, the first byte of each segment is the index of the segment.
segment = source[4:remainderOffset];
// The segment is the next `length` bytes after the first 5 bytes.
segment = source[5:remainderOffset];

// The remainder is the rest of the calldata.
remainder = source[remainderOffset:];
Expand All @@ -52,7 +53,7 @@ library SparseCalldataSegmentLib {
pure
returns (bytes memory, bytes calldata)
{
uint8 nextIndex = peekIndex(source);
uint8 nextIndex = getIndex(source);

if (nextIndex < index) {
revert SegmentOutOfOrder();
Expand All @@ -61,8 +62,6 @@ library SparseCalldataSegmentLib {
if (nextIndex == index) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

segment = getBody(segment);

if (segment.length == 0) {
revert NonCanonicalEncoding();
}
Expand All @@ -73,25 +72,16 @@ library SparseCalldataSegmentLib {
return ("", source);
}

/// @notice Extracts the final segment from the source.
/// @dev Reverts if the index of the segment is not RESERVED_VALIDATION_DATA_INDEX.
/// @param source The calldata to extract the segment from.
/// @return The final segment.
function getFinalSegment(bytes calldata source) internal pure returns (bytes calldata) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

if (getIndex(segment) != RESERVED_VALIDATION_DATA_INDEX) {
if (getIndex(source) != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

if (remainder.length != 0) {
revert NonCanonicalEncoding();
}

return getBody(segment);
}

/// @notice Returns the index of the next segment in the source.
/// @param source The calldata to extract the index from.
/// @return The index of the next segment.
function peekIndex(bytes calldata source) internal pure returns (uint8) {
return uint8(source[4]);
return source[1:];
}

/// @notice Extracts the index from a segment.
Expand All @@ -101,12 +91,4 @@ library SparseCalldataSegmentLib {
function getIndex(bytes calldata segment) internal pure returns (uint8) {
return uint8(segment[0]);
}

/// @notice Extracts the body from a segment.
/// @dev The body is the segment without the index.
/// @param segment The segment to extract the body from
/// @return The body of the segment.
function getBody(bytes calldata segment) internal pure returns (bytes calldata) {
return segment[1:];
}
}
27 changes: 17 additions & 10 deletions src/libraries/ValidationConfigLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ pragma solidity ^0.8.20;

import {ModuleEntity, ValidationConfig} from "../interfaces/IModularAccount.sol";

// Validation flags layout:
// 0b00000___ // unused
// 0b_____A__ // isGlobal
// 0b______B_ // isSignatureValidation
// 0b_______C // isUserOpValidation
type ValidationFlags is uint8;

// Validation config is a packed representation of a validation function and flags for its configuration.
// Layout:
// 0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA________________________ // Address
Expand Down Expand Up @@ -63,22 +70,22 @@ library ValidationConfigLib {
function unpackUnderlying(ValidationConfig config)
internal
pure
returns (address _module, uint32 _entityId, uint8 flags)
returns (address _module, uint32 _entityId, ValidationFlags flags)
{
bytes25 configBytes = ValidationConfig.unwrap(config);
_module = address(bytes20(configBytes));
_entityId = uint32(bytes4(configBytes << 160));
flags = uint8(configBytes[24]);
flags = ValidationFlags.wrap(uint8(configBytes[24]));
}

function unpack(ValidationConfig config)
internal
pure
returns (ModuleEntity _validationFunction, uint8 flags)
returns (ModuleEntity _validationFunction, ValidationFlags flags)
{
bytes25 configBytes = ValidationConfig.unwrap(config);
_validationFunction = ModuleEntity.wrap(bytes24(configBytes));
flags = uint8(configBytes[24]);
flags = ValidationFlags.wrap(uint8(configBytes[24]));
}

function module(ValidationConfig config) internal pure returns (address) {
Expand All @@ -97,23 +104,23 @@ library ValidationConfigLib {
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_GLOBAL != 0;
}

function isGlobal(uint8 flags) internal pure returns (bool) {
return flags & 0x04 != 0;
function isGlobal(ValidationFlags flags) internal pure returns (bool) {
return ValidationFlags.unwrap(flags) & 0x04 != 0;
}

function isSignatureValidation(ValidationConfig config) internal pure returns (bool) {
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_SIGNATURE != 0;
}

function isSignatureValidation(uint8 flags) internal pure returns (bool) {
return flags & 0x02 != 0;
function isSignatureValidation(ValidationFlags flags) internal pure returns (bool) {
return ValidationFlags.unwrap(flags) & 0x02 != 0;
}

function isUserOpValidation(ValidationConfig config) internal pure returns (bool) {
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_USER_OP != 0;
}

function isUserOpValidation(uint8 flags) internal pure returns (bool) {
return flags & 0x01 != 0;
function isUserOpValidation(ValidationFlags flags) internal pure returns (bool) {
return ValidationFlags.unwrap(flags) & 0x01 != 0;
}
}
2 changes: 1 addition & 1 deletion test/account/AccountReturnData.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ contract AccountReturnDataTest is AccountTestBase {

// Tests the ability to read data via executeWithRuntimeValidation
function test_returnData_authorized_exec() public {
bool result = ResultConsumerModule(address(account1)).checkResultExecuteWithAuthorization(
bool result = ResultConsumerModule(address(account1)).checkResultExecuteWithRuntimeValidation(
address(regularResultContract), keccak256("bar")
);

Expand Down
44 changes: 0 additions & 44 deletions test/account/PerHookData.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -246,34 +246,6 @@ contract PerHookDataTest is CustomValidationTestBase {
entryPoint.handleOps(userOps, beneficiary);
}

function test_failPerHookData_excessData_userOp() public {
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

userOp.signature = abi.encodePacked(
_encodeSignature(
_signerValidation, GLOBAL_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)
),
"extra data"
);

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
}

function test_passAccessControl_runtime() public {
assertEq(_counter.number(), 0);

Expand Down Expand Up @@ -420,22 +392,6 @@ contract PerHookDataTest is CustomValidationTestBase {
);
}

function test_failPerHookData_excessData_runtime() public {
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

vm.prank(owner1);
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
account1.executeWithRuntimeValidation(
abi.encodeCall(
ReferenceModularAccount.execute, (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
abi.encodePacked(
_encodeSignature(_signerValidation, GLOBAL_VALIDATION, preValidationHookData, ""), "extra data"
)
);
}

function test_pass1271AccessControl() public {
string memory message = "Hello, world!";

Expand Down
8 changes: 4 additions & 4 deletions test/libraries/SparseCalldataSegmentLib.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ contract SparseCalldataSegmentLibTest is Test {
bytes memory result = "";

for (uint256 i = 0; i < segments.length; i++) {
result = abi.encodePacked(result, uint32(segments[i].length), segments[i]);
result = abi.encodePacked(result, uint8(0), uint32(segments[i].length), segments[i]);
}

return result;
Expand All @@ -65,7 +65,7 @@ contract SparseCalldataSegmentLibTest is Test {
bytes memory result = "";

for (uint256 i = 0; i < segments.length; i++) {
result = abi.encodePacked(result, uint32(segments[i].length + 1), indices[i], segments[i]);
result = abi.encodePacked(result, indices[i], uint32(segments[i].length), segments[i]);
}

return result;
Expand Down Expand Up @@ -99,10 +99,10 @@ contract SparseCalldataSegmentLibTest is Test {

uint256 index = 0;
while (remainder.length > 0) {
indices[index] = remainder.getIndex();
bytes calldata segment;
(segment, remainder) = remainder.getNextSegment();
bodies[index] = segment.getBody();
indices[index] = segment.getIndex();
bodies[index] = segment;
index++;
}

Expand Down
12 changes: 7 additions & 5 deletions test/libraries/ValidationConfigLib.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ pragma solidity ^0.8.20;
import {Test} from "forge-std/Test.sol";

import {ModuleEntity, ModuleEntityLib} from "../../src/libraries/ModuleEntityLib.sol";
import {ValidationConfig, ValidationConfigLib} from "../../src/libraries/ValidationConfigLib.sol";
import {
ValidationConfig, ValidationConfigLib, ValidationFlags
} from "../../src/libraries/ValidationConfigLib.sol";

contract ValidationConfigLibTest is Test {
using ModuleEntityLib for ModuleEntity;
Expand All @@ -23,7 +25,7 @@ contract ValidationConfigLibTest is Test {
ValidationConfigLib.pack(module, entityId, isGlobal, isSignatureValidation, isUserOpValidation);

// Test unpacking underlying
(address module2, uint32 entityId2, uint8 flags2) = validationConfig.unpackUnderlying();
(address module2, uint32 entityId2, ValidationFlags flags2) = validationConfig.unpackUnderlying();

assertEq(module, module2, "module mismatch");
assertEq(entityId, entityId2, "entityId mismatch");
Expand All @@ -35,7 +37,7 @@ contract ValidationConfigLibTest is Test {

ModuleEntity expectedModuleEntity = ModuleEntityLib.pack(module, entityId);

(ModuleEntity validationFunction, uint8 flags3) = validationConfig.unpack();
(ModuleEntity validationFunction, ValidationFlags flags3) = validationConfig.unpack();

assertEq(
ModuleEntity.unwrap(validationFunction),
Expand Down Expand Up @@ -73,7 +75,7 @@ contract ValidationConfigLibTest is Test {

(address expectedModule, uint32 expectedEntityId) = validationFunction.unpack();

(address module, uint32 entityId, uint8 flags2) = validationConfig.unpackUnderlying();
(address module, uint32 entityId, ValidationFlags flags2) = validationConfig.unpackUnderlying();

assertEq(expectedModule, module, "module mismatch");
assertEq(expectedEntityId, entityId, "entityId mismatch");
Expand All @@ -83,7 +85,7 @@ contract ValidationConfigLibTest is Test {

// Test unpacking to ModuleEntity

(ModuleEntity validationFunction2, uint8 flags3) = validationConfig.unpack();
(ModuleEntity validationFunction2, ValidationFlags flags3) = validationConfig.unpack();

assertEq(
ModuleEntity.unwrap(validationFunction),
Expand Down
18 changes: 8 additions & 10 deletions test/mocks/modules/ReturnDataModuleMocks.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ pragma solidity ^0.8.20;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../../src/helpers/Constants.sol";
import {
ExecutionManifest,
IExecutionModule,
ManifestExecutionFunction
} from "../../../src/interfaces/IExecutionModule.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../../src/helpers/Constants.sol";

import {IModularAccount} from "../../../src/interfaces/IModularAccount.sol";
import {IValidationModule} from "../../../src/interfaces/IValidationModule.sol";

import {ModuleEntityLib} from "../../../src/libraries/ModuleEntityLib.sol";
import {BaseModule} from "../../../src/modules/BaseModule.sol";

import {ModuleSignatureUtils} from "../../utils/ModuleSignatureUtils.sol";

contract RegularResultContract {
function foo() external pure returns (bytes32) {
return keccak256("bar");
Expand Down Expand Up @@ -62,7 +62,7 @@ contract ResultCreatorModule is IExecutionModule, BaseModule {
}
}

contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule {
contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule, ModuleSignatureUtils {
ResultCreatorModule public immutable RESULT_CREATOR;
RegularResultContract public immutable REGULAR_RESULT_CONTRACT;

Expand Down Expand Up @@ -102,13 +102,11 @@ contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule
}

// Check the return data through the execute with authorization case
function checkResultExecuteWithAuthorization(address target, bytes32 expected) external returns (bool) {
function checkResultExecuteWithRuntimeValidation(address target, bytes32 expected) external returns (bool) {
// This result should be allowed based on the manifest permission request
bytes memory returnData = IModularAccount(msg.sender).executeWithRuntimeValidation(
abi.encodeCall(IModularAccount.execute, (target, 0, abi.encodeCall(RegularResultContract.foo, ()))),
abi.encodePacked(this, DIRECT_CALL_VALIDATION_ENTITYID, uint8(0), uint32(1), uint8(255)) // Validation
// function of self,
// selector-associated, with no auth data
_encodeSignature(ModuleEntityLib.pack(address(this), DIRECT_CALL_VALIDATION_ENTITYID), uint8(0), "")
);

bytes32 actual = abi.decode(abi.decode(returnData, (bytes)), (bytes32));
Expand All @@ -130,7 +128,7 @@ contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule
allowGlobalValidation: false
});
manifest.executionFunctions[1] = ManifestExecutionFunction({
executionSelector: this.checkResultExecuteWithAuthorization.selector,
executionSelector: this.checkResultExecuteWithRuntimeValidation.selector,
skipRuntimeValidation: true,
allowGlobalValidation: false
});
Expand Down
Loading

0 comments on commit 43d9d45

Please sign in to comment.