diff --git a/src/helpers/TimestampAsserterLocator.sol b/src/helpers/TimestampAsserterLocator.sol new file mode 100644 index 00000000..68698a8e --- /dev/null +++ b/src/helpers/TimestampAsserterLocator.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import "../interfaces/ITimestampAsserter.sol"; + +library TimestampAsserterLocator { + function locate() internal view returns (ITimestampAsserter) { + if (block.chainid == 260) { + return ITimestampAsserter(address(0x00000000000000000000000000000000808012)); + } + if (block.chainid == 300) { + revert("Timestamp asserter is not deployed on ZKsync Sepolia testnet yet"); + } + if (block.chainid == 324) { + revert("Timestamp asserter is not deployed on ZKsync mainnet yet"); + } + revert("Timestamp asserter is not deployed on this network"); + } +} diff --git a/src/interfaces/ITimestampAsserter.sol b/src/interfaces/ITimestampAsserter.sol new file mode 100644 index 00000000..132fc5c5 --- /dev/null +++ b/src/interfaces/ITimestampAsserter.sol @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +interface ITimestampAsserter { + function assertTimestampInRange(uint256 start, uint256 end) external view; +} diff --git a/src/libraries/SessionLib.sol b/src/libraries/SessionLib.sol index d3386db0..7bcf1a8a 100644 --- a/src/libraries/SessionLib.sol +++ b/src/libraries/SessionLib.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.24; import { Transaction } from "@matterlabs/zksync-contracts/l2/system-contracts/libraries/TransactionHelper.sol"; import { IPaymasterFlow } from "@matterlabs/zksync-contracts/l2/system-contracts/interfaces/IPaymasterFlow.sol"; +import { TimestampAsserterLocator } from "../helpers/TimestampAsserterLocator.sol"; library SessionLib { using SessionLib for SessionLib.Constraint; @@ -47,7 +48,7 @@ library SessionLib { mapping(address => uint256) lifetimeUsage; // Used for LimitType.Allowance // period => used that period - mapping(uint256 => mapping(address => uint256)) allowanceUsage; + mapping(uint64 => mapping(address => uint256)) allowanceUsage; } struct UsageLimit { @@ -116,20 +117,29 @@ library SessionLib { LimitState[] callParams; } - function checkAndUpdate(UsageLimit memory limit, UsageTracker storage tracker, uint256 value) internal { + function checkAndUpdate( + UsageLimit memory limit, + UsageTracker storage tracker, + uint256 value, + uint64 period + ) internal { if (limit.limitType == LimitType.Lifetime) { require(tracker.lifetimeUsage[msg.sender] + value <= limit.limit, "Lifetime limit exceeded"); tracker.lifetimeUsage[msg.sender] += value; } - // TODO: uncomment when it's possible to check timestamps during validation - // if (limit.limitType == LimitType.Allowance) { - // uint256 period = block.timestamp / limit.period; - // require(tracker.allowanceUsage[period] + value <= limit.limit); - // tracker.allowanceUsage[period] += value; - // } + if (limit.limitType == LimitType.Allowance) { + TimestampAsserterLocator.locate().assertTimestampInRange(period * limit.period, (period + 1) * limit.period); + require(tracker.allowanceUsage[period][msg.sender] + value <= limit.limit, "Allowance limit exceeded"); + tracker.allowanceUsage[period][msg.sender] += value; + } } - function checkAndUpdate(Constraint memory constraint, UsageTracker storage tracker, bytes calldata data) internal { + function checkAndUpdate( + Constraint memory constraint, + UsageTracker storage tracker, + bytes calldata data, + uint64 period + ) internal { uint256 index = 4 + constraint.index * 32; bytes32 param = bytes32(data[index:index + 32]); Condition condition = constraint.condition; @@ -149,18 +159,29 @@ library SessionLib { require(param != refValue, "NOT_EQUAL constraint not met"); } - constraint.limit.checkAndUpdate(tracker, uint256(param)); + constraint.limit.checkAndUpdate(tracker, uint256(param), period); } - function validate(SessionStorage storage state, Transaction calldata transaction, SessionSpec memory spec) internal { - require(state.status[msg.sender] == Status.Active, "Session is not active"); + function validate( + SessionStorage storage state, + Transaction calldata transaction, + SessionSpec memory spec, + uint64[] memory periodIds + ) internal { + // Here we additionally pass uint64[] periodId to check allowance limits + // periodId is defined as block.timestamp / limit.period if limitType == Allowance, and 0 otherwise (which will be ignored). + // periodIds[0] is for fee limit, + // periodIds[1] is for value limit, + // periodIds[2:] are for call constraints, if there are any. + // It is required to pass them in (instead of computing via block.timestamp) since during validation + // we can only assert the range of the timestamp, but not access its value. - // TODO uncomment when it's possible to check timestamps during validation - // require(block.timestamp <= session.expiresAt); + require(state.status[msg.sender] == Status.Active, "Session is not active"); + TimestampAsserterLocator.locate().assertTimestampInRange(0, spec.expiresAt); // TODO: update fee allowance with the gasleft/refund at the end of execution uint256 fee = transaction.maxFeePerGas * transaction.gasLimit; - spec.feeLimit.checkAndUpdate(state.fee, fee); + spec.feeLimit.checkAndUpdate(state.fee, fee, periodIds[0]); address target = address(uint160(transaction.to)); @@ -185,12 +206,12 @@ library SessionLib { } } - require(found, "Call not allowed"); + require(found, "Call to this contract is not allowed"); require(transaction.value <= callPolicy.maxValuePerUse, "Value exceeds limit"); - callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], transaction.value); + callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], transaction.value, periodIds[1]); for (uint256 i = 0; i < callPolicy.constraints.length; i++) { - callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], transaction.data); + callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], transaction.data, periodIds[i + 2]); } } else { TransferSpec memory transferPolicy; @@ -204,9 +225,9 @@ library SessionLib { } } - require(found, "Transfer not allowed"); + require(found, "Transfer to this address is not allowed"); require(transaction.value <= transferPolicy.maxValuePerUse, "Value exceeds limit"); - transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], transaction.value); + transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], transaction.value, periodIds[1]); } } @@ -224,7 +245,7 @@ library SessionLib { } if (limit.limitType == LimitType.Allowance) { // this is not used during validation, so it's fine to use block.timestamp - uint256 period = block.timestamp / limit.period; + uint64 period = uint64(block.timestamp / limit.period); return limit.limit - tracker.allowanceUsage[period][account]; } } diff --git a/src/validators/SessionKeyValidator.sol b/src/validators/SessionKeyValidator.sol index 7609b5bf..261b2636 100644 --- a/src/validators/SessionKeyValidator.sol +++ b/src/validators/SessionKeyValidator.sol @@ -139,11 +139,14 @@ contract SessionKeyValidator is IValidationHook, IModuleValidator, IModule { // This transaction is not meant to be validated by this module return; } - SessionLib.SessionSpec memory spec = abi.decode(hookData, (SessionLib.SessionSpec)); + (SessionLib.SessionSpec memory spec, uint64[] memory periodIds) = abi.decode( + hookData, + (SessionLib.SessionSpec, uint64[]) + ); (address recoveredAddress, ) = ECDSA.tryRecover(signedHash, signature); require(recoveredAddress == spec.signer, "Invalid signer"); bytes32 sessionHash = keccak256(abi.encode(spec)); - sessions[sessionHash].validate(transaction, spec); + sessions[sessionHash].validate(transaction, spec, periodIds); // Set the validation result to 1 for this hash, so that isValidSignature succeeds uint256 slot = uint256(signedHash); diff --git a/test/SessionKeyTest.ts b/test/SessionKeyTest.ts index f15018f1..9ae27db9 100644 --- a/test/SessionKeyTest.ts +++ b/test/SessionKeyTest.ts @@ -15,6 +15,77 @@ const fixtures = new ContractFixtures(); const abiCoder = new ethers.AbiCoder(); const provider = getProvider(); +const sessionSpecAbi = ethers.ParamType.from({ + components: [ + { name: "signer", type: "address" }, + { name: "expiresAt", type: "uint256" }, + { + components: [ + { name: "limitType", type: "uint8" }, + { name: "limit", type: "uint256" }, + { name: "period", type: "uint256" }, + ], + name: "feeLimit", + type: "tuple", + }, + { + components: [ + { name: "target", type: "address" }, + { name: "selector", type: "bytes4" }, + { name: "maxValuePerUse", type: "uint256" }, + { + components: [ + { name: "limitType", type: "uint8" }, + { name: "limit", type: "uint256" }, + { name: "period", type: "uint256" }, + ], + name: "valueLimit", + type: "tuple", + }, + { + components: [ + { name: "condition", type: "uint8" }, + { name: "index", type: "uint64" }, + { name: "refValue", type: "bytes32" }, + { + components: [ + { name: "limitType", type: "uint8" }, + { name: "limit", type: "uint256" }, + { name: "period", type: "uint256" }, + ], + name: "limit", + type: "tuple", + }, + ], + name: "constraints", + type: "tuple[]", + }, + ], + name: "callPolicies", + type: "tuple[]", + }, + { + components: [ + { name: "target", type: "address" }, + { name: "maxValuePerUse", type: "uint256" }, + { + components: [ + { name: "limitType", type: "uint8" }, + { name: "limit", type: "uint256" }, + { name: "period", type: "uint256" }, + ], + name: "valueLimit", + type: "tuple", + }, + ], + name: "transferPolicies", + type: "tuple[]", + }, + ], + name: "sessionSpec", + type: "tuple", +}); + enum Condition { Unconstrained = 0, Equal = 1, @@ -58,10 +129,20 @@ type PartialSession = { }[]; }; +async function getTimestamp() { + if (hre.network.name == "inMemoryNode") { + return Math.floor(await provider.send("config_getCurrentTimestamp", [])); + } else { + return Math.floor(Date.now() / 1000); + } +} + class SessionTester { public sessionOwner: Wallet; public session: SessionLib.SessionSpecStruct; public sessionAccount: SmartAccount; + // having this is a bit hacky, but it's so we can provide correct period ids in the signature + aaTransaction: ethers.TransactionLike; constructor(public proxyAccountAddress: string, sessionKeyModuleAddress: string) { this.sessionOwner = new Wallet(Wallet.createRandom().privateKey, provider); @@ -71,7 +152,10 @@ class SessionTester { [ this.sessionOwner.signingKey.sign(hash).serialized, sessionKeyModuleAddress, - [await this.encodeSession()], // this array supplies data for hooks + [abiCoder.encode( + [sessionSpecAbi, "uint64[]"], + [this.session, await this.periodIds(this.aaTransaction.to!, this.aaTransaction.data?.slice(0, 10))], + )], // this array supplies data for hooks ], ), address: this.proxyAccountAddress, @@ -107,9 +191,35 @@ class SessionTester { expect(newState.status).to.equal(1, "session should be active"); } - async encodeSession() { - const sessionKeyModuleContract = await fixtures.getSessionKeyContract(); - return "0x" + sessionKeyModuleContract.interface.encodeFunctionData("createSession", [this.session]).slice(10); + encodeSession() { + return abiCoder.encode([sessionSpecAbi], [this.session]); + } + + async periodIds(target: string, selector?: string) { + const timestamp = await getTimestamp(); + + const getId = (limit: SessionLib.UsageLimitStruct) => { + if (limit.limitType == LimitType.Allowance) { + return Math.floor(timestamp / Number(limit.period)); + } + return 0; + }; + + const isTransfer = selector == null || ethers.getBytes(selector).length < 4; + const policy: SessionLib.CallSpecStruct | SessionLib.TransferSpecStruct | undefined = isTransfer + ? this.session.transferPolicies.find((policy) => policy.target == target) + : this.session.callPolicies.find((policy) => policy.target == target && ethers.hexlify(policy.selector) == selector); + + if (policy == null) { + throw new Error("Transaction does not fit any policy"); + } + + const periodIds = [ + getId(this.session.feeLimit), + getId(policy.valueLimit), + ...(isTransfer ? [] : (policy).constraints.map((constraint) => getId(constraint.limit))), + ]; + return periodIds; } async revokeKey() { @@ -122,7 +232,7 @@ class SessionTester { secret: fixtures.wallet.privateKey, }, provider); - const sessionHash = ethers.keccak256(await this.encodeSession()); + const sessionHash = ethers.keccak256(this.encodeSession()); const aaTx = { ...await this.aaTxTemplate(), @@ -138,28 +248,29 @@ class SessionTester { expect(newState.status).to.equal(2, "session should be revoked"); } - async sendTxSuccess(txRequest: ethers.TransactionRequest = {}) { - const aaTx = { + async sendTxSuccess(txRequest: ethers.TransactionLike = {}) { + this.aaTransaction = { ...await this.aaTxTemplate(), ...txRequest, }; - aaTx.gasLimit = await provider.estimateGas(aaTx); - logInfo(`\`sessionTx\` gas estimated: ${await provider.estimateGas(aaTx)}`); + const estimatedGas = await provider.estimateGas(this.aaTransaction); + this.aaTransaction.gasLimit = BigInt(Math.ceil(Number(estimatedGas) * 1.1)); + logInfo(`\`sessionTx\` gas estimated: ${await provider.estimateGas(this.aaTransaction)}`); - const signedTransaction = await this.sessionAccount.signTransaction(aaTx); + const signedTransaction = await this.sessionAccount.signTransaction(this.aaTransaction); const tx = await provider.broadcastTransaction(signedTransaction); const receipt = await tx.wait(); logInfo(`\`sessionTx\` gas used: ${receipt.gasUsed}`); } - async sendTxFail(tx: ethers.TransactionRequest = {}) { - const aaTx = { + async sendTxFail(tx: ethers.TransactionLike = {}) { + this.aaTransaction = { ...await this.aaTxTemplate(), gasLimit: 100_000_000n, ...tx, }; - const signedTransaction = await this.sessionAccount.signTransaction(aaTx); + const signedTransaction = await this.sessionAccount.signTransaction(this.aaTransaction); await expect(provider.broadcastTransaction(signedTransaction)).to.be.reverted; }; @@ -210,6 +321,7 @@ class SessionTester { } async aaTxTemplate() { + const numberOfConstraints = this.session.callPolicies.map((policy) => policy.constraints.length); return { type: 113, from: this.proxyAccountAddress, @@ -225,7 +337,10 @@ class SessionTester { [ ethers.zeroPadValue("0x1b", 65), await fixtures.getSessionKeyModuleAddress(), - [await this.encodeSession()], + [abiCoder.encode( + [sessionSpecAbi, "uint64[]"], + [this.session, new Array(2 + Math.max(0, ...numberOfConstraints)).fill(0)], + )], ], ), }, @@ -263,7 +378,6 @@ describe("SessionKeyModule tests", function () { logInfo(`Session Address : ${await sessionModuleContract.getAddress()}`); logInfo(`Passkey Address : ${await verifierContract.getAddress()}`); logInfo(`Account Factory Address : ${await factoryContract.getAddress()}`); - // logInfo(`Account Implementation Address : ${await ssoContract.getAddress()}`); logInfo(`Auth Server Paymaster Address : ${await authServerPaymaster.getAddress()}`); }); @@ -293,7 +407,7 @@ describe("SessionKeyModule tests", function () { assert(await account.isModuleValidator(sessionKeyModuleAddress), "session key module should be a validator"); }); - describe("Value transfer limit test", function () { + describe("Value transfer limit tests", function () { let tester: SessionTester; const sessionTarget = Wallet.createRandom().address; @@ -324,7 +438,7 @@ describe("SessionKeyModule tests", function () { }); }); - describe("ERC20 transfer limit", function () { + describe("ERC20 transfer limit tests", function () { let tester: SessionTester; let erc20: ERC20; const sessionTarget = Wallet.createRandom().address; @@ -406,6 +520,66 @@ describe("SessionKeyModule tests", function () { }); }); + (hre.network.name == "inMemoryNode" ? describe : describe.skip)("Timestamp-based tests", function () { + let tester: SessionTester; + const sessionTarget = Wallet.createRandom().address; + const period = 120; + + it("should create a session", async () => { + tester = new SessionTester(proxyAccountAddress, await fixtures.getSessionKeyModuleAddress()); + await tester.createSession({ + expiresAt: await getTimestamp() + period * 3, + transferPolicies: [{ + target: sessionTarget, + maxValuePerUse: parseEther("0.01"), + valueLimit: { + limit: parseEther("0.015"), + period, + }, + }], + }); + }); + + it("should use a session key to send a transaction", async () => { + // We have to wait until the next period starts + const timestamp = Math.floor(await provider.send("config_getCurrentTimestamp", [])); + await provider.send("evm_increaseTime", [period - (timestamp % period)]); + // NOTE: this only works because `period` is > 60 seconds, since the default is + // `timestamp_asserter.min_time_till_end_sec = 60` + // We can sidestep the waiting by calling `evm_increaseTime` directly during the test, + // but also meaans that in production, creating sessions that expire in < 60 seconds is useless. + // Same goes for allowance time periods with duration < 60 seconds. + await tester.sendTxSuccess({ + to: sessionTarget, + value: parseEther("0.01"), + }); + }); + + it("should reject a transaction that goes over allowance limit", async () => { + await tester.sendTxFail({ + to: sessionTarget, + value: parseEther("0.01"), + }); + }); + + it("should wait until allowance renews and send a transaction", async () => { + await provider.send("evm_increaseTime", [period]); + await tester.sendTxSuccess({ + to: sessionTarget, + value: parseEther("0.01"), + }); + }); + + // TODO: check error messages as well + it("should reject a transaction with an expired session", async () => { + await provider.send("evm_increaseTime", [period * 2]); + await tester.sendTxFail({ + to: sessionTarget, + value: parseEther("0.01"), + }); + }); + }); + describe("Upgrade tests", function () { let beacon: AAFactory;