Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
augustbleeds committed Sep 30, 2024
1 parent b0ba9ec commit 4922344
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 22 deletions.
11 changes: 5 additions & 6 deletions contracts/src/access_control/rbac_timelock.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,10 @@ mod RBACTimelock {
}
}

//
// ONLY ADMIN
//

fn update_delay(ref self: ContractState, new_delay: u256) {
self.access_control.assert_only_role(ADMIN_ROLE);

Expand All @@ -387,10 +391,6 @@ mod RBACTimelock {
self._min_delay.write(new_delay);
}

//
// ONLY ADMIN
//

fn block_function_selector(ref self: ContractState, selector: felt252) {
self.access_control.assert_only_role(ADMIN_ROLE);

Expand Down Expand Up @@ -493,8 +493,7 @@ mod RBACTimelock {
}

fn _execute(ref self: ContractState, call: Call) {
let _response = call_contract_syscall(call.target, call.selector, call.data)
.unwrap_syscall();
call_contract_syscall(call.target, call.selector, call.data).unwrap_syscall();
}
}
}
4 changes: 4 additions & 0 deletions contracts/src/libraries/enumerable_set.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod EnumerableSetComponent {
use core::array::ArrayTrait;

// set is 1-indexed, not 0-indexed
#[storage]
pub struct Storage {
// access index by value
Expand All @@ -11,6 +12,7 @@ mod EnumerableSetComponent {
// access value by index
// set_id -> item_id -> item_value
// note: item_index is +1 because 0 means item is not in set
// note: _values.read(set_id, item_id) == 0, is only valid iff item_id <= _length.read(set_id)
pub _values: LegacyMap::<(u256, u256), u256>,
// set_id -> size of set
pub _length: LegacyMap<u256, u256>
Expand Down Expand Up @@ -77,6 +79,8 @@ mod EnumerableSetComponent {
}

fn at(self: @ComponentState<TContractState>, set_id: u256, index: u256) -> u256 {
assert(index != 0, 'set is 1-indexed');
assert(index <= self._length.read(set_id), 'index out of bounds');
self._values.read((set_id, index))
}

Expand Down
71 changes: 67 additions & 4 deletions contracts/src/tests/test_enumerable_set.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@ use snforge_std::{declare, ContractClassTrait};
const MOCK_SET_ID: u256 = 'adfasdf';
const OTHER_SET_ID: u256 = 'fakeasdf';

fn expect_out_of_bounds<T, impl TDrop: Drop<T>>(result: Result<T, Array<felt252>>) {
match result {
Result::Ok(_) => panic!("expect 'index out of bounds'"),
Result::Err(panic_data) => {
assert(*panic_data.at(0) == 'index out of bounds', *panic_data.at(0));
}
}
}

fn expect_set_is_1_indexed<T, impl TDrop: Drop<T>>(result: Result<T, Array<felt252>>) {
match result {
Result::Ok(_) => panic!("expect 'set is 1-indexed'"),
Result::Err(panic_data) => {
assert(*panic_data.at(0) == 'set is 1-indexed', *panic_data.at(0));
}
}
}

fn setup_mock() -> (
ContractAddress, IMockEnumerableSetDispatcher, IMockEnumerableSetSafeDispatcher
) {
Expand Down Expand Up @@ -56,8 +74,9 @@ fn test_add() {
}

#[test]
#[feature("safe_dispatcher")]
fn test_remove() {
let (_, mock, _) = setup_mock();
let (_, mock, safe_mock) = setup_mock();
let first_value = 12;

// ensure that removing other sets do not interfere with current set
Expand Down Expand Up @@ -90,7 +109,7 @@ fn test_remove() {
mock.contains(MOCK_SET_ID, 100) && mock.contains(MOCK_SET_ID, 200), 'contains 100 & 200'
);
assert(mock.at(MOCK_SET_ID, 1) == 100 && mock.at(MOCK_SET_ID, 2) == 200, 'indexes match');
assert(mock.at(MOCK_SET_ID, 3) == 0, 'no entry at 3rd index');
expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 3));
assert(mock.values(MOCK_SET_ID) == array![100, 200], 'values should match');

// [100, 200, 300]
Expand All @@ -104,7 +123,7 @@ fn test_remove() {
mock.contains(MOCK_SET_ID, 300) && mock.contains(MOCK_SET_ID, 200), 'contains 300 & 200'
);
assert(mock.at(MOCK_SET_ID, 1) == 300 && mock.at(MOCK_SET_ID, 2) == 200, 'indexes match');
assert(mock.at(MOCK_SET_ID, 3) == 0, 'no entry at 3rd index');
expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 3));
assert(mock.values(MOCK_SET_ID) == array![300, 200], 'values should match');

// [200]
Expand All @@ -113,7 +132,7 @@ fn test_remove() {
assert(!mock.contains(MOCK_SET_ID, 300), 'does not contain 300');
assert(mock.contains(MOCK_SET_ID, 200), 'contains 200');
assert(mock.at(MOCK_SET_ID, 1) == 200, 'indexes match');
assert(mock.at(MOCK_SET_ID, 2) == 0, 'no entry at 2nd index');
expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 2));
assert(mock.values(MOCK_SET_ID) == array![200], 'values should match');

// []
Expand Down Expand Up @@ -149,3 +168,47 @@ fn test_length() {
assert(mock.length(OTHER_SET_ID) == 0, 'should be 0');
}

#[test]
#[feature("safe_dispatcher")]
fn test_zero() {
let (_, mock, safe_mock) = setup_mock();

expect_set_is_1_indexed(safe_mock.at(MOCK_SET_ID, 0));

// [0]
assert(mock.add(MOCK_SET_ID, 0), 'should add 0');
assert(mock.contains(MOCK_SET_ID, 0), 'contains 0');

assert(mock.length(MOCK_SET_ID) == 1, 'should be 1');

// [0, 1]
assert(mock.add(MOCK_SET_ID, 1), 'should add 1');
assert(!mock.add(MOCK_SET_ID, 1), 'shouldnt add 1');

assert(mock.length(MOCK_SET_ID) == 2, 'should be 2');

assert(mock.at(MOCK_SET_ID, 1) == 0, 'set[1] = 0');
assert(mock.at(MOCK_SET_ID, 2) == 1, 'set[2] = 0');

// [1]
assert(mock.remove(MOCK_SET_ID, 0), 'should remove 0');
assert(!mock.remove(MOCK_SET_ID, 0), 'shouldnt remove 0');

assert(mock.at(MOCK_SET_ID, 1) == 1, 'set[1] = 1');
assert(!mock.contains(MOCK_SET_ID, 0), '0 is gone');
assert(mock.length(MOCK_SET_ID) == 1, 'length 1');

// []
assert(mock.remove(MOCK_SET_ID, 1), '1 removed');

// [0]
mock.add(MOCK_SET_ID, 0);

assert(mock.at(MOCK_SET_ID, 1) == 0, 'set[1] = 0');

// []
mock.remove(MOCK_SET_ID, 0);

expect_out_of_bounds(safe_mock.at(MOCK_SET_ID, 1));
}

73 changes: 61 additions & 12 deletions contracts/src/tests/test_rbac_timelock.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use chainlink::{
IMockMultisigTarget, IMockMultisigTargetDispatcherTrait, IMockMultisigTargetDispatcher
}
};

use openzeppelin::{
introspection::interface::{ISRC5, ISRC5Dispatcher, ISRC5DispatcherTrait, ISRC5_ID},
access::accesscontrol::{
Expand All @@ -22,6 +21,7 @@ use openzeppelin::{
},
token::{erc1155::interface::{IERC1155_RECEIVER_ID}, erc721::interface::{IERC721_RECEIVER_ID}}
};
use chainlink::tests::test_enumerable_set::{expect_out_of_bounds, expect_set_is_1_indexed};
use snforge_std::{
declare, ContractClassTrait, spy_events, EventSpyAssertionsTrait,
start_cheat_caller_address_global, start_cheat_block_timestamp_global
Expand Down Expand Up @@ -590,6 +590,55 @@ fn test_execute_successful() {
assert(actual_toggle, 'toggle true');

assert(timelock.is_operation_done(id), 'operation is done');

// let's try to schedule another batch of operations using the predecessor

let mock_time = 3000;
let mock_ready_time = mock_time + min_delay.try_into().unwrap();

let call_3 = Call {
target: target_address, selector: selector!("flip_toggle"), data: array![].span()
};
let calls = array![call_3].span();
let predecessor = id;
let salt = 2;

start_cheat_caller_address_global(proposer);
start_cheat_block_timestamp_global(mock_time);

timelock.schedule_batch(calls, predecessor, salt, min_delay);

let id = timelock.hash_operation_batch(calls, predecessor, salt);

start_cheat_caller_address_global(executor);
start_cheat_block_timestamp_global(mock_ready_time);

let mut spy = spy_events();

timelock.execute_batch(calls, predecessor, salt);

spy
.assert_emitted(
@array![
(
timelock_address,
RBACTimelock::Event::CallExecuted(
RBACTimelock::CallExecuted {
id: id,
index: 0,
target: call_3.target,
selector: call_3.selector,
data: call_3.data
}
)
)
]
);

let (_, actual_toggle) = target.read();
assert(!actual_toggle, 'toggle went to false again');

assert(timelock.is_operation_done(id), 'operation is done');
}

#[test]
Expand Down Expand Up @@ -761,9 +810,10 @@ fn test_unblock_selector() {
}

#[test]
#[feature("safe_dispatcher")]
fn test_blocked_selector_indexes() {
let (_, admin, _, _, _, _) = deploy_args();
let (_, timelock, _) = setup_timelock();
let (_, timelock, safe_timelock) = setup_timelock();

start_cheat_caller_address_global(admin);

Expand All @@ -775,38 +825,37 @@ fn test_blocked_selector_indexes() {
timelock.block_function_selector(selector2);
timelock.block_function_selector(selector3);

expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(5));
expect_set_is_1_indexed(safe_timelock.get_blocked_function_selector_at(0));

// [selector1, selector2, selector3]
assert(timelock.get_blocked_function_selector_count() == 3, 'count is 3');
assert(timelock.get_blocked_function_selector_at(1) == selector1, 'selector 1');
assert(timelock.get_blocked_function_selector_at(2) == selector2, 'selector 2');
assert(timelock.get_blocked_function_selector_at(3) == selector3, 'selector 3');
assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector');

timelock.unblock_function_selector(selector1);

// [selector3, selector2]
assert(timelock.get_blocked_function_selector_count() == 2, 'count is 2');
assert(timelock.get_blocked_function_selector_at(1) == selector3, 'selector 3');
assert(timelock.get_blocked_function_selector_at(2) == selector2, 'selector 2');
assert(timelock.get_blocked_function_selector_at(3) == 0, 'selector 3');
assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector');
expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(3));

timelock.unblock_function_selector(selector2);

// [selector3]
assert(timelock.get_blocked_function_selector_count() == 1, 'count is 1');
assert(timelock.get_blocked_function_selector_at(1) == selector3, 'selector 3');
assert(timelock.get_blocked_function_selector_at(2) == 0, 'no selector');
assert(timelock.get_blocked_function_selector_at(3) == 0, 'no selector');
assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector');
expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(2));
expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(3));

timelock.unblock_function_selector(selector3);

assert(timelock.get_blocked_function_selector_count() == 0, 'count is 0');
assert(timelock.get_blocked_function_selector_at(1) == 0, 'no selector');
assert(timelock.get_blocked_function_selector_at(2) == 0, 'no selector');
assert(timelock.get_blocked_function_selector_at(3) == 0, 'no selector');
assert(timelock.get_blocked_function_selector_at(0) == 0, 'no selector');
expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(1));
expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(2));
expect_out_of_bounds(safe_timelock.get_blocked_function_selector_at(3));
}

#[test]
Expand Down

0 comments on commit 4922344

Please sign in to comment.