Skip to content

Commit

Permalink
Use actualy dict for dict_squash and dict_copy from segment (kkrt-lab…
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementWalter authored Oct 14, 2024
1 parent 96dcf59 commit e8e1da9
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 240 deletions.
38 changes: 33 additions & 5 deletions cairo/src/account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from starkware.cairo.common.find_element import find_element

from src.interfaces.interfaces import ICairo1Helpers
from src.model import model
from src.utils.dict import default_dict_copy
from src.utils.dict import dict_copy, dict_squash
from src.utils.utils import Helpers
from src.utils.bytes import keccak

Expand Down Expand Up @@ -57,15 +57,43 @@ namespace Account {
}

// @dev Copy the Account to safely mutate the storage
// @dev Squash dicts used internally
// @param self The pointer to the Account
func copy{range_check_ptr}(self: model.Account*) -> model.Account* {
alloc_locals;
let (storage_start, storage) = default_dict_copy(self.storage_start, self.storage);
let (transient_storage_start, transient_storage) = default_dict_copy(
let (storage_start, storage) = dict_copy(self.storage_start, self.storage);
let (transient_storage_start, transient_storage) = dict_copy(
self.transient_storage_start, self.transient_storage
);
let (valid_jumpdests_start, valid_jumpdests) = default_dict_copy(
let (valid_jumpdests_start, valid_jumpdests) = dict_copy(
self.valid_jumpdests_start, self.valid_jumpdests
);
return new model.Account(
code_len=self.code_len,
code=self.code,
code_hash=self.code_hash,
storage_start=storage_start,
storage=storage,
transient_storage_start=transient_storage_start,
transient_storage=transient_storage,
valid_jumpdests_start=valid_jumpdests_start,
valid_jumpdests=valid_jumpdests,
nonce=self.nonce,
balance=self.balance,
selfdestruct=self.selfdestruct,
created=self.created,
);
}

// @dev Squash all the internal dicts for soundness
// @dev Squashed dicts are not default_dicts anymore
// @param self The pointer to the Account
func finalize{range_check_ptr}(self: model.Account*) -> model.Account* {
alloc_locals;
let (storage_start, storage) = dict_squash(self.storage_start, self.storage);
let (transient_storage_start, transient_storage) = dict_squash(
self.transient_storage_start, self.transient_storage
);
let (valid_jumpdests_start, valid_jumpdests) = dict_squash(
self.valid_jumpdests_start, self.valid_jumpdests
);
return new model.Account(
Expand Down
45 changes: 33 additions & 12 deletions cairo/src/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from starkware.cairo.common.bool import FALSE, TRUE
from src.account import Account
from src.model import model
from src.gas import Gas
from src.utils.dict import default_dict_copy
from src.utils.dict import dict_copy, dict_squash
from src.utils.utils import Helpers
from src.utils.uint256 import uint256_add, uint256_sub, uint256_eq

Expand All @@ -36,7 +36,7 @@ namespace State {
func copy{range_check_ptr, state: model.State*}() -> model.State* {
alloc_locals;
// accounts are a new memory segment
let (accounts_start, accounts) = default_dict_copy(state.accounts_start, state.accounts);
let (accounts_start, accounts) = dict_copy(state.accounts_start, state.accounts);
// for each account, storage is a new memory segment
Internals._copy_accounts{accounts=accounts}(accounts_start, accounts);

Expand All @@ -61,18 +61,16 @@ namespace State {
func finalize{range_check_ptr, state: model.State*}() {
alloc_locals;
// First squash to get only one account per key
let (local accounts_start, accounts_end) = default_dict_finalize(
state.accounts_start, state.accounts, 0
);
// All the account instances of the same key (address) use the same storage dict
// so it's safe to drop all the intermediate storage dicts.
let (local accounts_start, accounts) = dict_squash(state.accounts_start, state.accounts);

let (local accounts_copy: DictAccess*) = default_dict_new(0);
tempvar accounts_copy_start = accounts_copy;
// Squashes the storage dicts of accounts, and copy the result to a new memory segment.
Internals._copy_accounts{accounts=accounts_copy}(accounts_start, accounts_end);
// Then finalize each account. This also creates a new memory segment for the storage dict.
Internals._finalize_accounts{accounts=accounts}(accounts_start, accounts);

tempvar state = new model.State(
accounts_start=accounts_copy_start,
accounts=accounts_copy,
accounts_start=accounts_start,
accounts=accounts,
events_len=state.events_len,
events=state.events,
transfers_len=state.transfers_len,
Expand Down Expand Up @@ -393,7 +391,6 @@ namespace State {

namespace Internals {
// @notice Iterate through the accounts dict and copy them
// @dev Should be applied on a squashed dict
// @param accounts_start The dict start pointer
// @param accounts_end The dict end pointer
func _copy_accounts{range_check_ptr, accounts: DictAccess*}(
Expand All @@ -417,6 +414,30 @@ namespace Internals {
return _copy_accounts(accounts_start + DictAccess.SIZE, accounts_end);
}

// @notice Iterate through the accounts dict and finalize them
// @param accounts_start The dict start pointer
// @param accounts_end The dict end pointer
func _finalize_accounts{range_check_ptr, accounts: DictAccess*}(
accounts_start: DictAccess*, accounts_end: DictAccess*
) {
if (accounts_start == accounts_end) {
return ();
}

if (accounts_start.new_value == 0) {
// If we do a dict_read on an unexisting account, `prev_value` and `new_value` are set to 0.
// However we expected pointers to model.Account, and casting 0 to model.Account* will
// cause a "Memory address must be relocatable" error.
return _finalize_accounts(accounts_start + DictAccess.SIZE, accounts_end);
}

let account = cast(accounts_start.new_value, model.Account*);
let account = Account.finalize(account);
dict_write{dict_ptr=accounts}(key=accounts_start.key, new_value=cast(account, felt));

return _finalize_accounts(accounts_start + DictAccess.SIZE, accounts_end);
}

func _cache_precompile{pedersen_ptr: HashBuiltin*, range_check_ptr, accounts_ptr: DictAccess*}(
evm_address: felt
) {
Expand Down
153 changes: 31 additions & 122 deletions cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
@@ -1,134 +1,43 @@
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.default_dict import default_dict_new
from starkware.cairo.common.dict import dict_write, dict_squash
from starkware.cairo.common.math_cmp import is_not_zero
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.squash_dict import squash_dict
from starkware.cairo.common.uint256 import Uint256

from src.utils.maths import unsigned_div_rem

func dict_keys{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*) -> (
keys_len: felt, keys: felt*
) {
alloc_locals;
let (local keys_start: felt*) = alloc();
let dict_len = dict_end - dict_start;
let (local keys_len, _) = unsigned_div_rem(dict_len, DictAccess.SIZE);
local range_check_ptr = range_check_ptr;

if (dict_len == 0) {
return (keys_len, keys_start);
}

tempvar keys = keys_start;
tempvar len = keys_len;
tempvar dict = dict_start;

loop:
let keys = cast([ap - 3], felt*);
let len = [ap - 2];
let dict = cast([ap - 1], DictAccess*);

assert [keys] = dict.key;
tempvar keys = keys + 1;
tempvar len = len - 1;
tempvar dict = dict + DictAccess.SIZE;

static_assert keys == [ap - 3];
static_assert len == [ap - 2];
static_assert dict == [ap - 1];

jmp loop if len != 0;

return (keys_len, keys_start);
}

func dict_values{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*) -> (
values_len: felt, values: Uint256*
func dict_copy{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*) -> (
DictAccess*, DictAccess*
) {
alloc_locals;
let (local values: Uint256*) = alloc();
let dict_len = dict_end - dict_start;
let (local values_len, _) = unsigned_div_rem(dict_len, DictAccess.SIZE);
local range_check_ptr = range_check_ptr;

if (dict_len == 0) {
return (values_len, values);
}

tempvar index = 0;
tempvar len = values_len;
tempvar dict = dict_start;

loop:
let index = [ap - 3];
let len = [ap - 2];
let dict = cast([ap - 1], DictAccess*);

let pointer = cast(dict.new_value, Uint256*);
assert values[index] = pointer[0];

tempvar index = index + 1;
tempvar len = len - 1;
tempvar dict = dict + DictAccess.SIZE;

static_assert index == [ap - 3];
static_assert len == [ap - 2];
static_assert dict == [ap - 1];

jmp loop if len != 0;

return (values_len, values);
let (local new_start: DictAccess*) = alloc();
let new_end = new_start + (dict_end - dict_start);
memcpy(new_start, dict_start, dict_end - dict_start);
// Register the segment as a dict in the DictManager.
%{ dict_copy %}
return (new_start, new_end);
}

func default_dict_copy{range_check_ptr}(start: DictAccess*, end: DictAccess*) -> (
DictAccess*, DictAccess*
) {
// @dev Copied from the standard library with an updated dict_new() implementation.
func dict_squash{range_check_ptr}(
dict_accesses_start: DictAccess*, dict_accesses_end: DictAccess*
) -> (squashed_dict_start: DictAccess*, squashed_dict_end: DictAccess*) {
alloc_locals;
let (squashed_start, squashed_end) = dict_squash(start, end);
local range_check_ptr = range_check_ptr;
let dict_len = squashed_end - squashed_start;

local default_value;
if (dict_len == 0) {
assert default_value = 0;
} else {
assert default_value = squashed_start.prev_value;
}

let (local new_start) = default_dict_new(default_value);
let new_ptr = new_start;

if (dict_len == 0) {
return (new_start, new_ptr);
}

tempvar squashed_start = squashed_start;
tempvar dict_len = dict_len;
tempvar new_ptr = new_ptr;

loop:
let squashed_start = cast([ap - 3], DictAccess*);
let dict_len = [ap - 2];
let new_ptr = cast([ap - 1], DictAccess*);
let default_value = [fp + 1];

let key = [squashed_start].key;
let prev_value = [squashed_start].prev_value;
assert prev_value = default_value;
let new_value = [squashed_start].new_value;

dict_write{dict_ptr=new_ptr}(key=key, new_value=new_value);

tempvar squashed_start = squashed_start + DictAccess.SIZE;
tempvar dict_len = dict_len - DictAccess.SIZE;
tempvar new_ptr = new_ptr;

static_assert squashed_start == [ap - 3];
static_assert dict_len == [ap - 2];
static_assert new_ptr == [ap - 1];

jmp loop if dict_len != 0;

return (new_start, new_ptr);
%{ dict_squash %}
ap += 1;
let squashed_dict_start = cast([ap - 1], DictAccess*);

let (squashed_dict_end) = squash_dict(
dict_accesses=dict_accesses_start,
dict_accesses_end=dict_accesses_end,
squashed_dict=squashed_dict_start,
);

%{
# Update the DictTracker's current_ptr to point to the end of the squashed dict.
__dict_manager.get_tracker(ids.squashed_dict_start).current_ptr = \
ids.squashed_dict_end.address_
%}
return (squashed_dict_start=squashed_dict_start, squashed_dict_end=squashed_dict_end);
}
Loading

0 comments on commit e8e1da9

Please sign in to comment.