Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: skip signature length encoding on final sig and add type safety #192

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 16 additions & 34 deletions src/libraries/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,26 @@ library SparseCalldataSegmentLib {

/// @notice Splits out a segment of calldata, sparsely-packed.
/// The expected format is:
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
/// [uint8(index0), uint32(len(segment0)), segment0, uint8(index1), uint32(len(segment1)), segment1,
/// ... uint8(indexN), uint32(len(segmentN)), segmentN]
/// @param source The calldata to extract the segment from.
/// @return segment The extracted segment. Using the above example, this would be segment0.
/// @return remainder The remaining calldata. Using the above example,
/// this would start at uint32(len(segment1)) and continue to the end at segmentN.
/// this would start at uint8(index1) and continue to the end at segmentN.
function getNextSegment(bytes calldata source)
internal
pure
returns (bytes calldata segment, bytes calldata remainder)
{
// The first 4 bytes hold the length of the segment, excluding the index.
uint32 length = uint32(bytes4(source[:4]));
// The first byte of the segment is the index.
// The next 4 bytes hold the length of the segment, excluding the index.
uint32 length = uint32(bytes4(source[1:5]));

// The offset of the remainder of the calldata.
uint256 remainderOffset = 4 + length;
uint256 remainderOffset = 5 + length;

// The segment is the next `length` + 1 bytes, to account for the index.
// By convention, the first byte of each segment is the index of the segment.
segment = source[4:remainderOffset];
// The segment is the next `length` bytes after the first 5 bytes.
segment = source[5:remainderOffset];

// The remainder is the rest of the calldata.
remainder = source[remainderOffset:];
Expand All @@ -52,7 +53,7 @@ library SparseCalldataSegmentLib {
pure
returns (bytes memory, bytes calldata)
{
uint8 nextIndex = peekIndex(source);
uint8 nextIndex = getIndex(source);

if (nextIndex < index) {
revert SegmentOutOfOrder();
Expand All @@ -61,8 +62,6 @@ library SparseCalldataSegmentLib {
if (nextIndex == index) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

segment = getBody(segment);

if (segment.length == 0) {
revert NonCanonicalEncoding();
}
Expand All @@ -73,25 +72,16 @@ library SparseCalldataSegmentLib {
return ("", source);
}

/// @notice Extracts the final segment from the source.
/// @dev Reverts if the index of the segment is not RESERVED_VALIDATION_DATA_INDEX.
/// @param source The calldata to extract the segment from.
/// @return The final segment.
function getFinalSegment(bytes calldata source) internal pure returns (bytes calldata) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

if (getIndex(segment) != RESERVED_VALIDATION_DATA_INDEX) {
if (getIndex(source) != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

if (remainder.length != 0) {
revert NonCanonicalEncoding();
}

return getBody(segment);
}

/// @notice Returns the index of the next segment in the source.
/// @param source The calldata to extract the index from.
/// @return The index of the next segment.
function peekIndex(bytes calldata source) internal pure returns (uint8) {
return uint8(source[4]);
return source[1:];
}

/// @notice Extracts the index from a segment.
Expand All @@ -101,12 +91,4 @@ library SparseCalldataSegmentLib {
function getIndex(bytes calldata segment) internal pure returns (uint8) {
return uint8(segment[0]);
}

/// @notice Extracts the body from a segment.
/// @dev The body is the segment without the index.
/// @param segment The segment to extract the body from
/// @return The body of the segment.
function getBody(bytes calldata segment) internal pure returns (bytes calldata) {
return segment[1:];
}
}
27 changes: 17 additions & 10 deletions src/libraries/ValidationConfigLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ pragma solidity ^0.8.20;

import {ModuleEntity, ValidationConfig} from "../interfaces/IModularAccount.sol";

// Validation flags layout:
// 0b00000___ // unused
// 0b_____A__ // isGlobal
// 0b______B_ // isSignatureValidation
// 0b_______C // isUserOpValidation
type ValidationFlags is uint8;

// Validation config is a packed representation of a validation function and flags for its configuration.
// Layout:
// 0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA________________________ // Address
Expand Down Expand Up @@ -63,22 +70,22 @@ library ValidationConfigLib {
function unpackUnderlying(ValidationConfig config)
internal
pure
returns (address _module, uint32 _entityId, uint8 flags)
returns (address _module, uint32 _entityId, ValidationFlags flags)
{
bytes25 configBytes = ValidationConfig.unwrap(config);
_module = address(bytes20(configBytes));
_entityId = uint32(bytes4(configBytes << 160));
flags = uint8(configBytes[24]);
flags = ValidationFlags.wrap(uint8(configBytes[24]));
}

function unpack(ValidationConfig config)
internal
pure
returns (ModuleEntity _validationFunction, uint8 flags)
returns (ModuleEntity _validationFunction, ValidationFlags flags)
{
bytes25 configBytes = ValidationConfig.unwrap(config);
_validationFunction = ModuleEntity.wrap(bytes24(configBytes));
flags = uint8(configBytes[24]);
flags = ValidationFlags.wrap(uint8(configBytes[24]));
}

function module(ValidationConfig config) internal pure returns (address) {
Expand All @@ -97,23 +104,23 @@ library ValidationConfigLib {
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_GLOBAL != 0;
}

function isGlobal(uint8 flags) internal pure returns (bool) {
return flags & 0x04 != 0;
function isGlobal(ValidationFlags flags) internal pure returns (bool) {
return ValidationFlags.unwrap(flags) & 0x04 != 0;
}

function isSignatureValidation(ValidationConfig config) internal pure returns (bool) {
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_SIGNATURE != 0;
}

function isSignatureValidation(uint8 flags) internal pure returns (bool) {
return flags & 0x02 != 0;
function isSignatureValidation(ValidationFlags flags) internal pure returns (bool) {
return ValidationFlags.unwrap(flags) & 0x02 != 0;
}

function isUserOpValidation(ValidationConfig config) internal pure returns (bool) {
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_USER_OP != 0;
}

function isUserOpValidation(uint8 flags) internal pure returns (bool) {
return flags & 0x01 != 0;
function isUserOpValidation(ValidationFlags flags) internal pure returns (bool) {
return ValidationFlags.unwrap(flags) & 0x01 != 0;
}
}
2 changes: 1 addition & 1 deletion test/account/AccountReturnData.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ contract AccountReturnDataTest is AccountTestBase {

// Tests the ability to read data via executeWithRuntimeValidation
function test_returnData_authorized_exec() public {
bool result = ResultConsumerModule(address(account1)).checkResultExecuteWithAuthorization(
bool result = ResultConsumerModule(address(account1)).checkResultExecuteWithRuntimeValidation(
address(regularResultContract), keccak256("bar")
);

Expand Down
44 changes: 0 additions & 44 deletions test/account/PerHookData.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -246,34 +246,6 @@ contract PerHookDataTest is CustomValidationTestBase {
entryPoint.handleOps(userOps, beneficiary);
}

function test_failPerHookData_excessData_userOp() public {
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

userOp.signature = abi.encodePacked(
_encodeSignature(
_signerValidation, GLOBAL_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)
),
"extra data"
);

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
}

function test_passAccessControl_runtime() public {
assertEq(_counter.number(), 0);

Expand Down Expand Up @@ -420,22 +392,6 @@ contract PerHookDataTest is CustomValidationTestBase {
);
}

function test_failPerHookData_excessData_runtime() public {
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

vm.prank(owner1);
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
account1.executeWithRuntimeValidation(
abi.encodeCall(
ReferenceModularAccount.execute, (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
abi.encodePacked(
_encodeSignature(_signerValidation, GLOBAL_VALIDATION, preValidationHookData, ""), "extra data"
)
);
}

function test_pass1271AccessControl() public {
string memory message = "Hello, world!";

Expand Down
8 changes: 4 additions & 4 deletions test/libraries/SparseCalldataSegmentLib.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ contract SparseCalldataSegmentLibTest is Test {
bytes memory result = "";

for (uint256 i = 0; i < segments.length; i++) {
result = abi.encodePacked(result, uint32(segments[i].length), segments[i]);
result = abi.encodePacked(result, uint8(0), uint32(segments[i].length), segments[i]);
}

return result;
Expand All @@ -65,7 +65,7 @@ contract SparseCalldataSegmentLibTest is Test {
bytes memory result = "";

for (uint256 i = 0; i < segments.length; i++) {
result = abi.encodePacked(result, uint32(segments[i].length + 1), indices[i], segments[i]);
result = abi.encodePacked(result, indices[i], uint32(segments[i].length), segments[i]);
}

return result;
Expand Down Expand Up @@ -99,10 +99,10 @@ contract SparseCalldataSegmentLibTest is Test {

uint256 index = 0;
while (remainder.length > 0) {
indices[index] = remainder.getIndex();
bytes calldata segment;
(segment, remainder) = remainder.getNextSegment();
bodies[index] = segment.getBody();
indices[index] = segment.getIndex();
bodies[index] = segment;
index++;
}

Expand Down
12 changes: 7 additions & 5 deletions test/libraries/ValidationConfigLib.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ pragma solidity ^0.8.20;
import {Test} from "forge-std/Test.sol";

import {ModuleEntity, ModuleEntityLib} from "../../src/libraries/ModuleEntityLib.sol";
import {ValidationConfig, ValidationConfigLib} from "../../src/libraries/ValidationConfigLib.sol";
import {
ValidationConfig, ValidationConfigLib, ValidationFlags
} from "../../src/libraries/ValidationConfigLib.sol";

contract ValidationConfigLibTest is Test {
using ModuleEntityLib for ModuleEntity;
Expand All @@ -23,7 +25,7 @@ contract ValidationConfigLibTest is Test {
ValidationConfigLib.pack(module, entityId, isGlobal, isSignatureValidation, isUserOpValidation);

// Test unpacking underlying
(address module2, uint32 entityId2, uint8 flags2) = validationConfig.unpackUnderlying();
(address module2, uint32 entityId2, ValidationFlags flags2) = validationConfig.unpackUnderlying();

assertEq(module, module2, "module mismatch");
assertEq(entityId, entityId2, "entityId mismatch");
Expand All @@ -35,7 +37,7 @@ contract ValidationConfigLibTest is Test {

ModuleEntity expectedModuleEntity = ModuleEntityLib.pack(module, entityId);

(ModuleEntity validationFunction, uint8 flags3) = validationConfig.unpack();
(ModuleEntity validationFunction, ValidationFlags flags3) = validationConfig.unpack();

assertEq(
ModuleEntity.unwrap(validationFunction),
Expand Down Expand Up @@ -73,7 +75,7 @@ contract ValidationConfigLibTest is Test {

(address expectedModule, uint32 expectedEntityId) = validationFunction.unpack();

(address module, uint32 entityId, uint8 flags2) = validationConfig.unpackUnderlying();
(address module, uint32 entityId, ValidationFlags flags2) = validationConfig.unpackUnderlying();

assertEq(expectedModule, module, "module mismatch");
assertEq(expectedEntityId, entityId, "entityId mismatch");
Expand All @@ -83,7 +85,7 @@ contract ValidationConfigLibTest is Test {

// Test unpacking to ModuleEntity

(ModuleEntity validationFunction2, uint8 flags3) = validationConfig.unpack();
(ModuleEntity validationFunction2, ValidationFlags flags3) = validationConfig.unpack();

assertEq(
ModuleEntity.unwrap(validationFunction),
Expand Down
18 changes: 8 additions & 10 deletions test/mocks/modules/ReturnDataModuleMocks.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ pragma solidity ^0.8.20;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../../src/helpers/Constants.sol";
import {
ExecutionManifest,
IExecutionModule,
ManifestExecutionFunction
} from "../../../src/interfaces/IExecutionModule.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../../src/helpers/Constants.sol";

import {IModularAccount} from "../../../src/interfaces/IModularAccount.sol";
import {IValidationModule} from "../../../src/interfaces/IValidationModule.sol";

import {ModuleEntityLib} from "../../../src/libraries/ModuleEntityLib.sol";
import {BaseModule} from "../../../src/modules/BaseModule.sol";

import {ModuleSignatureUtils} from "../../utils/ModuleSignatureUtils.sol";

contract RegularResultContract {
function foo() external pure returns (bytes32) {
return keccak256("bar");
Expand Down Expand Up @@ -62,7 +62,7 @@ contract ResultCreatorModule is IExecutionModule, BaseModule {
}
}

contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule {
contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule, ModuleSignatureUtils {
ResultCreatorModule public immutable RESULT_CREATOR;
RegularResultContract public immutable REGULAR_RESULT_CONTRACT;

Expand Down Expand Up @@ -102,13 +102,11 @@ contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule
}

// Check the return data through the execute with authorization case
function checkResultExecuteWithAuthorization(address target, bytes32 expected) external returns (bool) {
function checkResultExecuteWithRuntimeValidation(address target, bytes32 expected) external returns (bool) {
// This result should be allowed based on the manifest permission request
bytes memory returnData = IModularAccount(msg.sender).executeWithRuntimeValidation(
abi.encodeCall(IModularAccount.execute, (target, 0, abi.encodeCall(RegularResultContract.foo, ()))),
abi.encodePacked(this, DIRECT_CALL_VALIDATION_ENTITYID, uint8(0), uint32(1), uint8(255)) // Validation
// function of self,
// selector-associated, with no auth data
_encodeSignature(ModuleEntityLib.pack(address(this), DIRECT_CALL_VALIDATION_ENTITYID), uint8(0), "")
);

bytes32 actual = abi.decode(abi.decode(returnData, (bytes)), (bytes32));
Expand All @@ -130,7 +128,7 @@ contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule
allowGlobalValidation: false
});
manifest.executionFunctions[1] = ManifestExecutionFunction({
executionSelector: this.checkResultExecuteWithAuthorization.selector,
executionSelector: this.checkResultExecuteWithRuntimeValidation.selector,
skipRuntimeValidation: true,
allowGlobalValidation: false
});
Expand Down
Loading
Loading