From 4922344fe60cf72cdc7f55b59bdc52fefb025201 Mon Sep 17 00:00:00 2001 From: Augustus Chang Date: Mon, 30 Sep 2024 17:06:43 -0400 Subject: [PATCH] address comments --- .../src/access_control/rbac_timelock.cairo | 11 ++- contracts/src/libraries/enumerable_set.cairo | 4 + contracts/src/tests/test_enumerable_set.cairo | 71 +++++++++++++++++- contracts/src/tests/test_rbac_timelock.cairo | 73 ++++++++++++++++--- 4 files changed, 137 insertions(+), 22 deletions(-) diff --git a/contracts/src/access_control/rbac_timelock.cairo b/contracts/src/access_control/rbac_timelock.cairo index c709003da..5fe078996 100644 --- a/contracts/src/access_control/rbac_timelock.cairo +++ b/contracts/src/access_control/rbac_timelock.cairo @@ -373,6 +373,10 @@ mod RBACTimelock { } } + // + // ONLY ADMIN + // + fn update_delay(ref self: ContractState, new_delay: u256) { self.access_control.assert_only_role(ADMIN_ROLE); @@ -387,10 +391,6 @@ mod RBACTimelock { self._min_delay.write(new_delay); } - // - // ONLY ADMIN - // - fn block_function_selector(ref self: ContractState, selector: felt252) { self.access_control.assert_only_role(ADMIN_ROLE); @@ -493,8 +493,7 @@ mod RBACTimelock { } fn _execute(ref self: ContractState, call: Call) { - let _response = call_contract_syscall(call.target, call.selector, call.data) - .unwrap_syscall(); + call_contract_syscall(call.target, call.selector, call.data).unwrap_syscall(); } } } diff --git a/contracts/src/libraries/enumerable_set.cairo b/contracts/src/libraries/enumerable_set.cairo index 54e0b024b..791b35606 100644 --- a/contracts/src/libraries/enumerable_set.cairo +++ b/contracts/src/libraries/enumerable_set.cairo @@ -2,6 +2,7 @@ mod EnumerableSetComponent { use core::array::ArrayTrait; + // set is 1-indexed, not 0-indexed #[storage] pub struct Storage { // access index by value @@ -11,6 +12,7 @@ mod EnumerableSetComponent { // access value by index // set_id -> item_id -> item_value // note: item_index is +1 because 0 means item is not in set + // note: _values.read(set_id, item_id) == 0, is only valid iff item_id <= _length.read(set_id) pub _values: LegacyMap::<(u256, u256), u256>, // set_id -> size of set pub _length: LegacyMap @@ -77,6 +79,8 @@ mod EnumerableSetComponent { } fn at(self: @ComponentState, set_id: u256, index: u256) -> u256 { + assert(index != 0, 'set is 1-indexed'); + assert(index <= self._length.read(set_id), 'index out of bounds'); self._values.read((set_id, index)) } diff --git a/contracts/src/tests/test_enumerable_set.cairo b/contracts/src/tests/test_enumerable_set.cairo index b5d23ef8f..1f4e008b8 100644 --- a/contracts/src/tests/test_enumerable_set.cairo +++ b/contracts/src/tests/test_enumerable_set.cairo @@ -9,6 +9,24 @@ use snforge_std::{declare, ContractClassTrait}; const MOCK_SET_ID: u256 = 'adfasdf'; const OTHER_SET_ID: u256 = 'fakeasdf'; +fn expect_out_of_bounds>(result: Result>) { + match result { + Result::Ok(_) => panic!("expect 'index out of bounds'"), + Result::Err(panic_data) => { + assert(*panic_data.at(0) == 'index out of bounds', *panic_data.at(0)); + } + } +} + +fn expect_set_is_1_indexed>(result: Result>) { + match result { + Result::Ok(_) => panic!("expect 'set is 1-indexed'"), + Result::Err(panic_data) => { + assert(*panic_data.at(0) == 'set is 1-indexed', *panic_data.at(0)); + } + } +} + fn setup_mock() -> ( ContractAddress, IMockEnumerableSetDispatcher, IMockEnumerableSetSafeDispatcher ) { @@ -56,8 +74,9 @@ fn test_add() { } #[test] +#[feature("safe_dispatcher")] fn test_remove() { - let (_, mock, _) = setup_mock(); + let (_, mock, safe_mock) = setup_mock(); let first_value = 12; // ensure that removing other sets do not interfere with current set @@ -90,7 +109,7 @@ fn test_remove() { mock.contains(MOCK_SET_ID, 100) && mock.contains(MOCK_SET_ID, 200), 'contains 100 & 200' ); assert(mock.at(MOCK_SET_ID, 1) == 100 && mock.at(MOCK_SET_ID, 2) == 200, 'indexes match'); - assert(mock.at(MOCK_SET_ID, 3) == 0, 'no entry at 3rd index'); + expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 3)); assert(mock.values(MOCK_SET_ID) == array![100, 200], 'values should match'); // [100, 200, 300] @@ -104,7 +123,7 @@ fn test_remove() { mock.contains(MOCK_SET_ID, 300) && mock.contains(MOCK_SET_ID, 200), 'contains 300 & 200' ); assert(mock.at(MOCK_SET_ID, 1) == 300 && mock.at(MOCK_SET_ID, 2) == 200, 'indexes match'); - assert(mock.at(MOCK_SET_ID, 3) == 0, 'no entry at 3rd index'); + expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 3)); assert(mock.values(MOCK_SET_ID) == array![300, 200], 'values should match'); // [200] @@ -113,7 +132,7 @@ fn test_remove() { assert(!mock.contains(MOCK_SET_ID, 300), 'does not contain 300'); assert(mock.contains(MOCK_SET_ID, 200), 'contains 200'); assert(mock.at(MOCK_SET_ID, 1) == 200, 'indexes match'); - assert(mock.at(MOCK_SET_ID, 2) == 0, 'no entry at 2nd index'); + expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 2)); assert(mock.values(MOCK_SET_ID) == array![200], 'values should match'); // [] @@ -149,3 +168,47 @@ fn test_length() { assert(mock.length(OTHER_SET_ID) == 0, 'should be 0'); } +#[test] +#[feature("safe_dispatcher")] +fn test_zero() { + let (_, mock, safe_mock) = setup_mock(); + + expect_set_is_1_indexed(safe_mock.at(MOCK_SET_ID, 0)); + + // [0] + assert(mock.add(MOCK_SET_ID, 0), 'should add 0'); + assert(mock.contains(MOCK_SET_ID, 0), 'contains 0'); + + assert(mock.length(MOCK_SET_ID) == 1, 'should be 1'); + + // [0, 1] + assert(mock.add(MOCK_SET_ID, 1), 'should add 1'); + assert(!mock.add(MOCK_SET_ID, 1), 'shouldnt add 1'); + + assert(mock.length(MOCK_SET_ID) == 2, 'should be 2'); + + assert(mock.at(MOCK_SET_ID, 1) == 0, 'set[1] = 0'); + assert(mock.at(MOCK_SET_ID, 2) == 1, 'set[2] = 0'); + + // [1] + assert(mock.remove(MOCK_SET_ID, 0), 'should remove 0'); + assert(!mock.remove(MOCK_SET_ID, 0), 'shouldnt remove 0'); + + assert(mock.at(MOCK_SET_ID, 1) == 1, 'set[1] = 1'); + assert(!mock.contains(MOCK_SET_ID, 0), '0 is gone'); + assert(mock.length(MOCK_SET_ID) == 1, 'length 1'); + + // [] + assert(mock.remove(MOCK_SET_ID, 1), '1 removed'); + + // [0] + mock.add(MOCK_SET_ID, 0); + + assert(mock.at(MOCK_SET_ID, 1) == 0, 'set[1] = 0'); + + // [] + mock.remove(MOCK_SET_ID, 0); + + expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 1)); +} + diff --git a/contracts/src/tests/test_rbac_timelock.cairo b/contracts/src/tests/test_rbac_timelock.cairo index ad3d4a950..dd29ccc62 100644 --- a/contracts/src/tests/test_rbac_timelock.cairo +++ b/contracts/src/tests/test_rbac_timelock.cairo @@ -10,7 +10,6 @@ use chainlink::{ IMockMultisigTarget, IMockMultisigTargetDispatcherTrait, IMockMultisigTargetDispatcher } }; - use openzeppelin::{ introspection::interface::{ISRC5, ISRC5Dispatcher, ISRC5DispatcherTrait, ISRC5_ID}, access::accesscontrol::{ @@ -22,6 +21,7 @@ use openzeppelin::{ }, token::{erc1155::interface::{IERC1155_RECEIVER_ID}, erc721::interface::{IERC721_RECEIVER_ID}} }; +use chainlink::tests::test_enumerable_set::{expect_out_of_bounds, expect_set_is_1_indexed}; use snforge_std::{ declare, ContractClassTrait, spy_events, EventSpyAssertionsTrait, start_cheat_caller_address_global, start_cheat_block_timestamp_global @@ -590,6 +590,55 @@ fn test_execute_successful() { assert(actual_toggle, 'toggle true'); assert(timelock.is_operation_done(id), 'operation is done'); + + // let's try to schedule another batch of operations using the predecessor + + let mock_time = 3000; + let mock_ready_time = mock_time + min_delay.try_into().unwrap(); + + let call_3 = Call { + target: target_address, selector: selector!("flip_toggle"), data: array![].span() + }; + let calls = array![call_3].span(); + let predecessor = id; + let salt = 2; + + start_cheat_caller_address_global(proposer); + start_cheat_block_timestamp_global(mock_time); + + timelock.schedule_batch(calls, predecessor, salt, min_delay); + + let id = timelock.hash_operation_batch(calls, predecessor, salt); + + start_cheat_caller_address_global(executor); + start_cheat_block_timestamp_global(mock_ready_time); + + let mut spy = spy_events(); + + timelock.execute_batch(calls, predecessor, salt); + + spy + .assert_emitted( + @array![ + ( + timelock_address, + RBACTimelock::Event::CallExecuted( + RBACTimelock::CallExecuted { + id: id, + index: 0, + target: call_3.target, + selector: call_3.selector, + data: call_3.data + } + ) + ) + ] + ); + + let (_, actual_toggle) = target.read(); + assert(!actual_toggle, 'toggle went to false again'); + + assert(timelock.is_operation_done(id), 'operation is done'); } #[test] @@ -761,9 +810,10 @@ fn test_unblock_selector() { } #[test] +#[feature("safe_dispatcher")] fn test_blocked_selector_indexes() { let (_, admin, _, _, _, _) = deploy_args(); - let (_, timelock, _) = setup_timelock(); + let (_, timelock, safe_timelock) = setup_timelock(); start_cheat_caller_address_global(admin); @@ -775,12 +825,14 @@ fn test_blocked_selector_indexes() { timelock.block_function_selector(selector2); timelock.block_function_selector(selector3); + expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(5)); + expect_set_is_1_indexed(safe_timelock.get_blocked_function_selector_at(0)); + // [selector1, selector2, selector3] assert(timelock.get_blocked_function_selector_count() == 3, 'count is 3'); assert(timelock.get_blocked_function_selector_at(1) == selector1, 'selector 1'); assert(timelock.get_blocked_function_selector_at(2) == selector2, 'selector 2'); assert(timelock.get_blocked_function_selector_at(3) == selector3, 'selector 3'); - assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector'); timelock.unblock_function_selector(selector1); @@ -788,25 +840,22 @@ fn test_blocked_selector_indexes() { assert(timelock.get_blocked_function_selector_count() == 2, 'count is 2'); assert(timelock.get_blocked_function_selector_at(1) == selector3, 'selector 3'); assert(timelock.get_blocked_function_selector_at(2) == selector2, 'selector 2'); - assert(timelock.get_blocked_function_selector_at(3) == 0, 'selector 3'); - assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector'); + expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(3)); timelock.unblock_function_selector(selector2); // [selector3] assert(timelock.get_blocked_function_selector_count() == 1, 'count is 1'); assert(timelock.get_blocked_function_selector_at(1) == selector3, 'selector 3'); - assert(timelock.get_blocked_function_selector_at(2) == 0, 'no selector'); - assert(timelock.get_blocked_function_selector_at(3) == 0, 'no selector'); - assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector'); + expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(2)); + expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(3)); timelock.unblock_function_selector(selector3); assert(timelock.get_blocked_function_selector_count() == 0, 'count is 0'); - assert(timelock.get_blocked_function_selector_at(1) == 0, 'no selector'); - assert(timelock.get_blocked_function_selector_at(2) == 0, 'no selector'); - assert(timelock.get_blocked_function_selector_at(3) == 0, 'no selector'); - assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector'); + expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(1)); + expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(2)); + expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(3)); } #[test]