From 55a26f4228dcff9c51d3d6bfd5e6638bafe63101 Mon Sep 17 00:00:00 2001 From: Lior Goldberg Date: Tue, 4 Oct 2022 17:32:00 +0300 Subject: [PATCH 1/4] Cairo v0.10.1. --- Dockerfile | 2 +- README.md | 4 +- scripts/requirements-deps.json | 6 +- .../everest/business_logic/CMakeLists.txt | 3 +- .../business_logic/internal_transaction.py | 21 +- src/services/everest/business_logic/state.py | 8 +- .../transaction_execution_objects.py | 20 + src/starkware/cairo/common/CMakeLists.txt | 1 + .../cairo/common/cairo_function_runner.py | 58 ++- src/starkware/cairo/common/structs.py | 54 ++- src/starkware/cairo/common/uint256.cairo | 58 +++ src/starkware/cairo/lang/VERSION | 2 +- .../cairo/lang/builtins/CMakeLists.txt | 8 +- .../cairo/lang/compiler/CMakeLists.txt | 2 +- src/starkware/cairo/lang/compiler/parser.py | 2 +- .../cairo/lang/compiler/parser_transformer.py | 2 +- .../preprocessor/auxiliary_info_collector.py | 9 +- .../compiler/preprocessor/preprocessor.py | 11 +- .../cairo/lang/ide/vscode-cairo/package.json | 2 +- .../cairo/lang/migrators/migrator.py | 8 +- src/starkware/cairo/lang/vm/CMakeLists.txt | 1 + src/starkware/cairo/lang/vm/cairo_runner.py | 4 +- src/starkware/eth/eth_test_utils.py | 16 +- src/starkware/python/async_subprocess.py | 39 +- src/starkware/python/utils.py | 34 +- src/starkware/python/utils_test.py | 35 +- .../business_logic/execution/CMakeLists.txt | 2 +- .../execution/execute_entry_point.py | 3 +- .../execution/execute_entry_point_base.py | 4 +- .../business_logic/execution/gas_usage.py | 20 +- .../business_logic/execution/objects.py | 67 ++- .../execution/os_resources.json | 14 +- .../business_logic/execution/os_usage.py | 4 +- .../fact_state/patricia_state.py | 46 +- .../business_logic/fact_state/state.py | 39 +- .../starknet/business_logic/state/state.py | 161 ++++++- .../business_logic/state/state_api.py | 25 ++ .../business_logic/state/state_api_objects.py | 2 +- .../business_logic/transaction/objects.py | 418 +++++++++++++++--- .../transaction/state_objects.py | 87 +++- .../starknet/business_logic/utils.py | 20 +- src/starkware/starknet/cli/CMakeLists.txt | 1 - src/starkware/starknet/cli/starknet_cli.py | 225 ++++++---- src/starkware/starknet/common/constants.cairo | 1 + .../starknet/common/eth_utils_test.py | 3 +- src/starkware/starknet/common/storage_test.py | 10 +- .../starknet/compiler/validation_utils.py | 80 +++- .../compiler/validation_utils_test.py | 181 +++++--- .../starknet/core/os/contracts.cairo | 2 +- .../starknet/core/os/program_hash.json | 2 +- .../starknet/core/os/syscall_utils.py | 44 +- .../os/transaction_hash/transaction_hash.py | 25 ++ .../transaction_hash/transaction_hash_test.py | 53 ++- .../starknet/core/os/transactions.cairo | 154 ++++++- .../core/test_contract/dummy_account.cairo | 5 + .../starknet/definitions/CMakeLists.txt | 2 + .../starknet/definitions/error_codes.py | 1 + src/starkware/starknet/definitions/fields.py | 21 +- .../starknet/definitions/general_config.yml | 4 +- .../starknet/definitions/transaction_type.py | 1 + src/starkware/starknet/public/abi.py | 7 +- .../starknet/security/CMakeLists.txt | 1 + .../starknet/security/starknet_common.cairo | 1 + .../cairo_sha256_arbitrary_input_length.json | 20 + .../starknet/security/whitelists/latest.json | 16 + .../starknet/services/api/CMakeLists.txt | 1 + .../api/feeder_gateway/response_objects.py | 109 ++++- .../services/api/gateway/transaction.py | 49 +- .../starknet/services/api/messages.py | 3 +- .../services/utils/sequencer_api_utils.py | 17 +- .../starknet/storage/starknet_storage.py | 18 + .../testing/MockStarknetMessaging.sol | 2 +- src/starkware/starknet/testing/starknet.py | 6 +- .../starknet/testing/starknet_test.py | 29 +- .../third_party/open_zeppelin/Account.cairo | 9 + src/starkware/starknet/wallets/CMakeLists.txt | 3 +- src/starkware/starknet/wallets/account.py | 52 ++- .../starknet/wallets/open_zeppelin.py | 269 +++++++---- .../starkware_utils/error_handling.py | 7 +- .../marshmallow_dataclass_fields.py | 7 +- 80 files changed, 2125 insertions(+), 638 deletions(-) create mode 100644 src/starkware/starknet/security/whitelists/cairo_sha256_arbitrary_input_length.json diff --git a/Dockerfile b/Dockerfile index 66c63f97..da2ae455 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,7 +9,7 @@ RUN pip install cmake==3.22 RUN curl https://binaries.soliditylang.org/linux-amd64/solc-linux-amd64-v0.6.12+commit.27d51765 -o /usr/local/bin/solc-0.6.12 RUN echo 'f6cb519b01dabc61cab4c184a3db11aa591d18151e362fcae850e42cffdfb09a /usr/local/bin/solc-0.6.12' | sha256sum --check RUN chmod +x /usr/local/bin/solc-0.6.12 -RUN npm install -g --unsafe-perm ganache-cli@6.12.2 +RUN npm install -g --unsafe-perm ganache@7.4.3 COPY . /app/ diff --git a/README.md b/README.md index 415ec225..4d3fd005 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ We recommend starting from [Setting up the environment](https://cairo-lang.org/d # Installation instructions You should be able to download the python package zip file directly from -[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.10.0) +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.10.1) and install it using ``pip``. See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.10.0.zip . +> docker cp ${container_id}:/app/cairo-lang-0.10.1.zip . > docker rm -v ${container_id} ``` diff --git a/scripts/requirements-deps.json b/scripts/requirements-deps.json index 14b72073..eb6fc7a7 100644 --- a/scripts/requirements-deps.json +++ b/scripts/requirements-deps.json @@ -472,9 +472,9 @@ { "dependencies": [], "package": { - "installed_version": "0.12.0", - "key": "lark-parser", - "package_name": "lark-parser" + "installed_version": "1.1.2", + "key": "lark", + "package_name": "lark" } }, { diff --git a/src/services/everest/business_logic/CMakeLists.txt b/src/services/everest/business_logic/CMakeLists.txt index e4f7d523..f0ee659d 100644 --- a/src/services/everest/business_logic/CMakeLists.txt +++ b/src/services/everest/business_logic/CMakeLists.txt @@ -29,11 +29,10 @@ python_lib(everest_internal_transaction_lib LIBS everest_business_logic_lib everest_business_logic_state_api_lib + everest_transaction_execution_objects_lib everest_transaction_lib starkware_config_utils_lib - starkware_dataclasses_utils_lib starkware_one_of_schema_utils_lib - pip_marshmallow_dataclass ) python_lib(everest_transaction_execution_objects_lib diff --git a/src/services/everest/business_logic/internal_transaction.py b/src/services/everest/business_logic/internal_transaction.py index 057246b4..1f175f26 100644 --- a/src/services/everest/business_logic/internal_transaction.py +++ b/src/services/everest/business_logic/internal_transaction.py @@ -3,29 +3,14 @@ from abc import abstractmethod from typing import Iterable, Iterator, Optional, Type -import marshmallow_dataclass - from services.everest.api.gateway.transaction import EverestTransaction from services.everest.business_logic.state import StateSelectorBase from services.everest.business_logic.state_api import StateProxy +from services.everest.business_logic.transaction_execution_objects import ( + EverestTransactionExecutionInfo, +) from starkware.starkware_utils.config_base import Config from starkware.starkware_utils.one_of_schema_tracker import SubclassSchemaTracker -from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass - - -class EverestTransactionExecutionInfo(ValidatedMarshmallowDataclass): - """ - Base class of classes containing information generated from an execution of a transaction on - the state. Each Everest application may implement it specifically. - Note that this object will only be relevant if the transaction executed successfully. - """ - - -@marshmallow_dataclass.dataclass(frozen=True) -class TransactionExecutionInfo(EverestTransactionExecutionInfo): - """ - A non-abstract derived class for completeness of AggregatedScope. Used by StarkEx and Perpetual. - """ class EverestInternalStateTransaction(SubclassSchemaTracker): diff --git a/src/services/everest/business_logic/state.py b/src/services/everest/business_logic/state.py index 2277c5db..5c6b6b53 100644 --- a/src/services/everest/business_logic/state.py +++ b/src/services/everest/business_logic/state.py @@ -226,8 +226,12 @@ async def commit( """ @classmethod - def squash_many(cls: Type[TStateDiff], state_diffs: Iterable[TStateDiff]) -> TStateDiff: + def squash_many( + cls: Type[TStateDiff], + state_diffs: Iterable[TStateDiff], + initial_state_diff: TStateDiff, + ) -> TStateDiff: """ Creates a state diff. object with the given changes applied in chronological order. """ - return functools.reduce(lambda x, y: x.squash(other=y), state_diffs) + return functools.reduce(lambda x, y: x.squash(other=y), state_diffs, initial_state_diff) diff --git a/src/services/everest/business_logic/transaction_execution_objects.py b/src/services/everest/business_logic/transaction_execution_objects.py index b4e5cf58..8300c910 100644 --- a/src/services/everest/business_logic/transaction_execution_objects.py +++ b/src/services/everest/business_logic/transaction_execution_objects.py @@ -3,9 +3,25 @@ import marshmallow import marshmallow_dataclass +from starkware.starkware_utils.error_handling import StarkException from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass +class EverestTransactionExecutionInfo(ValidatedMarshmallowDataclass): + """ + Base class of classes containing information generated from an execution of a transaction on + the state. Each Everest application may implement it specifically. + Note that this object will only be relevant if the transaction executed successfully. + """ + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionExecutionInfo(EverestTransactionExecutionInfo): + """ + A non-abstract derived class for completeness of AggregatedScope. Used by StarkEx and Perpetual. + """ + + @marshmallow_dataclass.dataclass(frozen=True) class TransactionFailureReason(ValidatedMarshmallowDataclass): """ @@ -30,3 +46,7 @@ def truncate_error_message(self, data: Dict[str, Any], many: bool, **kwargs) -> data["error_message"] = error_message[:5000] return data + + @classmethod + def from_exception(cls, exception: StarkException) -> "TransactionFailureReason": + return cls(code=exception.code.name, error_message=exception.message) diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index dd8714e8..dd315ba1 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -74,6 +74,7 @@ python_lib(cairo_common_validate_utils_lib cairo_run_builtins_lib cairo_run_lib cairo_vm_lib + starkware_python_utils_lib ) python_lib(cairo_function_runner_lib diff --git a/src/starkware/cairo/common/cairo_function_runner.py b/src/starkware/cairo/common/cairo_function_runner.py index 906396cc..30e2b6bc 100644 --- a/src/starkware/cairo/common/cairo_function_runner.py +++ b/src/starkware/cairo/common/cairo_function_runner.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, Optional, Tuple, Union, cast from starkware.cairo.common.structs import CairoStructFactory from starkware.cairo.lang.builtins.bitwise.bitwise_builtin_runner import BitwiseBuiltinRunner @@ -23,6 +23,7 @@ from starkware.cairo.lang.vm.security import verify_secure_runner from starkware.cairo.lang.vm.utils import RunResources from starkware.cairo.lang.vm.vm_exceptions import SecurityError, VmException +from starkware.python.utils import safe_zip class CairoFunctionRunner(CairoRunner): @@ -130,26 +131,33 @@ def run( trace_on_failure: bool = False, apply_modulo_to_args: Optional[bool] = None, use_full_name: bool = False, + verify_implicit_args_segment: bool = False, **kwargs, - ): + ) -> Tuple[Tuple[MaybeRelocatable, ...], Tuple[MaybeRelocatable, ...]]: """ Runs func_name(*args). args are converted to Cairo-friendly ones using gen_arg. + Returns the return values of the function, splitted into 2 tuples of implicit values and + explicit values. Structs will be flattened to a sequence of felts as part of the returned + tuple. + Additional params: verify_secure - Run verify_secure_runner to do extra verifications. trace_on_failure - Run the tracer in case of failure to help debugging. apply_modulo_to_args - Apply modulo operation on integer arguments. use_full_name - Treat 'func_name' as a fully qualified identifier name, rather than a - relative one. + relative one. + verify_implicit_args_segment - For each implicit argument, verify that the argument and the + return value are in the same segment. """ assert isinstance(self.program, Program) entrypoint = self.program.get_label(func_name, full_name_lookup=use_full_name) structs_factory = CairoStructFactory.from_program(program=self.program) - full_args_struct = structs_factory.build_func_args( - func=ScopedName.from_string(scope=func_name) - ) + func = ScopedName.from_string(scope=func_name) + + full_args_struct = structs_factory.build_func_args(func=func) all_args = full_args_struct(*args, **kwargs) try: @@ -173,6 +181,44 @@ def run( trace_runner(runner=self) raise + # The number of implicit arguments is identical to the number of implicit return values. + n_implicit_ret_vals = structs_factory.get_implicit_args_length(func=func) + n_explicit_ret_vals = structs_factory.get_explicit_return_values_length(func=func) + n_ret_vals = n_explicit_ret_vals + n_implicit_ret_vals + implicit_retvals = tuple( + self.vm_memory.get_range( + addr=self.vm.run_context.ap - n_ret_vals, size=n_implicit_ret_vals + ) + ) + + explicit_retvals = tuple( + self.vm_memory.get_range( + addr=self.vm.run_context.ap - n_explicit_ret_vals, size=n_explicit_ret_vals + ) + ) + + # Verify the memory segments of the implicit arguments. + if verify_implicit_args_segment: + implicit_args = all_args[:n_implicit_ret_vals] + for implicit_arg, implicit_retval in safe_zip(implicit_args, implicit_retvals): + assert isinstance( + implicit_arg, RelocatableValue + ), f"Implicit arguments must be RelocatableValues, {implicit_arg} is not." + assert isinstance(implicit_retval, RelocatableValue), ( + f"Argument {implicit_arg} is a RelocatableValue, but the returned value " + f"{implicit_retval} is not." + ) + assert implicit_arg.segment_index == implicit_retval.segment_index, ( + f"Implicit argument {implicit_arg} is not on the same segment as the returned " + f"{implicit_retval}." + ) + assert implicit_retval.offset >= implicit_arg.offset, ( + f"The offset of the returned implicit argument {implicit_retval} is less than " + f"the offset of the input {implicit_arg}." + ) + + return implicit_retvals, explicit_retvals + def run_from_entrypoint( self, entrypoint: Union[str, int], diff --git a/src/starkware/cairo/common/structs.py b/src/starkware/cairo/common/structs.py index dc3359f9..dfca5fcd 100644 --- a/src/starkware/cairo/common/structs.py +++ b/src/starkware/cairo/common/structs.py @@ -1,9 +1,20 @@ from typing import List, MutableMapping, NamedTuple, Optional +from starkware.cairo.lang.compiler.ast.cairo_types import ( + CairoType, + TypeCodeoffset, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction from starkware.cairo.lang.compiler.identifier_definition import StructDefinition from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager -from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition +from starkware.cairo.lang.compiler.identifier_utils import ( + get_struct_definition, + get_type_definition, +) from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.python.utils import WriteOnceDict @@ -91,6 +102,47 @@ def build_func_args(self, func: ScopedName): return NamedTuple(f"{func[-1:]}_full_args", typed_fields) + def size_of(self, typ: CairoType) -> int: + """ + Returns the total size (in felts) of the given type. Pointer types count as one felt. + """ + size = 0 + if isinstance(typ, TypeStruct): + struct_def = get_struct_definition( + struct_name=typ.scope, identifier_manager=self.identifiers + ) + size = struct_def.size + elif isinstance(typ, TypeTuple): + for item in typ.members: + size += self.size_of(typ=item.typ) + else: + assert isinstance( + typ, (TypeFelt, TypeCodeoffset, TypePointer) + ), f"Unsupported Cairo type {typ}." + size = 1 + + return size + + def get_explicit_return_values_length(self, func: ScopedName) -> int: + """ + Returns the length of the explicit return values of a function + """ + full_name = self._get_full_name(func) + type_def = get_type_definition( + full_name + CodeElementFunction.RETURN_SCOPE, self.identifiers + ) + return self.size_of(typ=type_def.cairo_type) + + def get_implicit_args_length(self, func: ScopedName) -> int: + """ + Returns the length of the implicit arguments of a function + """ + full_name = self._get_full_name(func) + struct_def = get_struct_definition( + full_name + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, self.identifiers + ) + return struct_def.size + @property def structs(self): """ diff --git a/src/starkware/cairo/common/uint256.cairo b/src/starkware/cairo/common/uint256.cairo index 8599cb89..5c8a4e72 100644 --- a/src/starkware/cairo/common/uint256.cairo +++ b/src/starkware/cairo/common/uint256.cairo @@ -222,6 +222,64 @@ func uint256_unsigned_div_rem{range_check_ptr}(a: Uint256, div: Uint256) -> ( return (quotient=quotient, remainder=remainder); } +// Computes: +// 1. The integer division `(a * b) // div` (as a 512-bit number). +// 2. The remainder `(a * b) modulo div`. +// Assumption: div != 0. +func uint256_mul_div_mod{range_check_ptr}(a: Uint256, b: Uint256, div: Uint256) -> ( + quotient_low: Uint256, quotient_high: Uint256, remainder: Uint256 +) { + alloc_locals; + + // Compute a * b (512 bits). + let (ab_low, ab_high) = uint256_mul(a, b); + + // Guess the quotient and remainder of (a * b) / d. + local quotient_low: Uint256; + local quotient_high: Uint256; + local remainder: Uint256; + + %{ + a = (ids.a.high << 128) + ids.a.low + b = (ids.b.high << 128) + ids.b.low + div = (ids.div.high << 128) + ids.div.low + quotient, remainder = divmod(a * b, div) + + ids.quotient_low.low = quotient & ((1 << 128) - 1) + ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1) + ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1) + ids.quotient_high.high = quotient >> 384 + ids.remainder.low = remainder & ((1 << 128) - 1) + ids.remainder.high = remainder >> 128 + %} + + // Compute x = quotient * div + remainder. + uint256_check(quotient_high); + let (quotient_mod10, quotient_mod11) = uint256_mul(quotient_high, div); + uint256_check(quotient_low); + let (quotient_mod00, quotient_mod01) = uint256_mul(quotient_low, div); + // Since x should equal a * b, the high 256 bits must be zero. + assert quotient_mod11 = Uint256(0, 0); + + // The low 256 bits of x must be ab_low. + uint256_check(remainder); + let (x0, carry0) = uint256_add(quotient_mod00, remainder); + assert x0 = ab_low; + + let (x1, carry1) = uint256_add(quotient_mod01, quotient_mod10); + assert carry1 = 0; + let (x1, carry2) = uint256_add(x1, Uint256(low=carry0, high=0)); + assert carry2 = 0; + + assert x1 = ab_high; + + // Verify that 0 <= remainder < div. + let (is_valid) = uint256_lt(remainder, div); + assert is_valid = 1; + + return (quotient_low=quotient_low, quotient_high=quotient_high, remainder=remainder); +} + // Returns the bitwise NOT of an integer. func uint256_not{range_check_ptr}(a: Uint256) -> (res: Uint256) { return (res=Uint256(low=ALL_ONES - a.low, high=ALL_ONES - a.high)); diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 78bc1abd..57121573 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.10.0 +0.10.1 diff --git a/src/starkware/cairo/lang/builtins/CMakeLists.txt b/src/starkware/cairo/lang/builtins/CMakeLists.txt index 5d9e1836..aff20e74 100644 --- a/src/starkware/cairo/lang/builtins/CMakeLists.txt +++ b/src/starkware/cairo/lang/builtins/CMakeLists.txt @@ -6,7 +6,6 @@ python_lib(cairo_run_builtins_lib PREFIX starkware/cairo/lang/builtins FILES - all_builtins.py bitwise/bitwise_builtin_runner.py bitwise/instance_def.py ec/ec_op_builtin_runner.py @@ -27,6 +26,13 @@ python_lib(cairo_run_builtins_lib starkware_python_utils_lib ) +python_lib(cairo_all_builtins_lib + PREFIX starkware/cairo/lang/builtins + + FILES + all_builtins.py +) + python_lib(cairo_run_builtins_test_utils_lib PREFIX starkware/cairo/lang/builtins FILES diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt index ea09cd6d..189c865c 100644 --- a/src/starkware/cairo/lang/compiler/CMakeLists.txt +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -94,7 +94,7 @@ python_lib(cairo_compile_lib pip_marshmallow_enum pip_marshmallow_oneofschema pip_marshmallow - pip_lark_parser + pip_lark ) python_exe(cairo_compile_exe diff --git a/src/starkware/cairo/lang/compiler/parser.py b/src/starkware/cairo/lang/compiler/parser.py index 49a3fb3d..678e6a79 100644 --- a/src/starkware/cairo/lang/compiler/parser.py +++ b/src/starkware/cairo/lang/compiler/parser.py @@ -234,7 +234,7 @@ def parse( def lex(code: str) -> List[lark.lexer.Token]: """ - Runs the lexer on the given code and returns the lark-parser tokens. + Runs the lexer on the given code and returns the lark tokens. """ return list(GRAMMAR_PARSER.lex(code)) diff --git a/src/starkware/cairo/lang/compiler/parser_transformer.py b/src/starkware/cairo/lang/compiler/parser_transformer.py index 851f9ec4..771877b0 100644 --- a/src/starkware/cairo/lang/compiler/parser_transformer.py +++ b/src/starkware/cairo/lang/compiler/parser_transformer.py @@ -628,7 +628,7 @@ def code_element_empty_line(self, value): @v_args(meta=True) def commented_code_element(self, meta, value): - comment = value[1][2:] if len(value) == 2 else None + comment = value[1][2:] if value[1] is not None else None return CommentedCodeElement( code_elm=value[0], comment=comment, location=self.meta2loc(meta) ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/auxiliary_info_collector.py b/src/starkware/cairo/lang/compiler/preprocessor/auxiliary_info_collector.py index 30a3aaac..6785105f 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/auxiliary_info_collector.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/auxiliary_info_collector.py @@ -29,8 +29,7 @@ def start_function_info( start_pc: int, implicit_args_struct: StructDefinition, args_struct: StructDefinition, - ret_names: List[str], - ret_types: List[CairoType], + ret_types: Optional[CairoType], ): pass @@ -38,6 +37,10 @@ def start_function_info( def finish_function_info(self, end_pc: int, total_ap_change: RegChange): pass + @abstractmethod + def start_function_retry(self): + pass + @abstractmethod def add_assert_eq(self, lhs: Expression, rhs: Expression): pass @@ -98,7 +101,7 @@ def start_return(self): pass @abstractmethod - def finish_return(self, exprs: List[Expression]): + def finish_return(self, expr: Expression): pass @abstractmethod diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py index bbac267f..f8144b7e 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py @@ -667,8 +667,7 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): start_pc=self.current_pc, implicit_args_struct=implicit_args_struct, args_struct=args_struct, - ret_names=[] if elm.returns is None else ["ret"], - ret_types=[] if elm.returns is None else [self.resolve_type(elm.returns)], + ret_type=None if elm.returns is None else self.resolve_type(elm.returns), ) # Process code_elements. @@ -723,10 +722,12 @@ def visit_function_body_with_retries(self, code_block: CodeBlock, location: Opti default_location=location, ) + if self.auxiliary_info is not None: + self.auxiliary_info.start_function_retry() + # These cases cannot be fixed for reference revocations: # * Functions without alloc_locals. - # * Contexts with self.auxiliary_info. - if not has_alloc_locals or self.auxiliary_info is not None: + if not has_alloc_locals: self.visit_uncommented_code_block(code_elements) return @@ -1578,7 +1579,7 @@ def visit_CodeElementReturn(self, elm: CodeElementReturn): self.visit(code_elm_ret) if self.auxiliary_info is not None: - self.auxiliary_info.finish_return(exprs=[elm.expr]) + self.auxiliary_info.finish_return(expr=elm.expr) def check_tail_call_cast( self, diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index ba045f60..0f3d2af5 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.10.0", + "version": "0.10.1", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/migrators/migrator.py b/src/starkware/cairo/lang/migrators/migrator.py index 4be2f772..07c967a2 100644 --- a/src/starkware/cairo/lang/migrators/migrator.py +++ b/src/starkware/cairo/lang/migrators/migrator.py @@ -94,18 +94,18 @@ def code_element_function(self, value): ) @lark.v_args(meta=True) - def commented_code_element(self, value, meta): - comment = value[1][1:] if len(value) == 2 else None + def commented_code_element(self, meta, value): + comment = value[1][1:] if value[1] is not None else None return CommentedCodeElement( code_elm=value[0], comment=comment, location=self.meta2loc(meta) ) @lark.v_args(meta=True) - def code_element_return(self, value, meta): + def code_element_return(self, meta, value): (expr,) = value if not isinstance(expr, ExprParentheses): - return super().code_element_return(value, meta) + return super().code_element_return(meta, value) # Replace the outer parentheses with an ExprTuple with has_trailing_comma=True. location = self.meta2loc(meta) diff --git a/src/starkware/cairo/lang/vm/CMakeLists.txt b/src/starkware/cairo/lang/vm/CMakeLists.txt index e5894ed4..2e7de166 100644 --- a/src/starkware/cairo/lang/vm/CMakeLists.txt +++ b/src/starkware/cairo/lang/vm/CMakeLists.txt @@ -43,6 +43,7 @@ python_lib(cairo_vm_lib vm_exceptions.py LIBS + cairo_all_builtins_lib cairo_compile_lib cairo_relocatable_lib cairo_vm_crypto_lib diff --git a/src/starkware/cairo/lang/vm/cairo_runner.py b/src/starkware/cairo/lang/vm/cairo_runner.py index f3c106e6..830db408 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner.py +++ b/src/starkware/cairo/lang/vm/cairo_runner.py @@ -251,8 +251,10 @@ def initialize_state(self, entrypoint: Union[str, int], stack: Sequence[MaybeRel self.load_data(self.execution_base, stack) def initialize_vm( - self, hint_locals, static_locals: Optional[Dict[str, Any]] = None, vm_class=VirtualMachine + self, hint_locals, static_locals: Optional[Dict[str, Any]] = None, vm_class=None ): + if vm_class is None: + vm_class = VirtualMachine context = RunContext( pc=self.initial_pc, ap=self.initial_ap, diff --git a/src/starkware/eth/eth_test_utils.py b/src/starkware/eth/eth_test_utils.py index fedb7195..2eedf2eb 100644 --- a/src/starkware/eth/eth_test_utils.py +++ b/src/starkware/eth/eth_test_utils.py @@ -51,7 +51,7 @@ def context_manager(cls): def advance_time(self, n_seconds: int): self.w3.provider.make_request( - method=web3_types.RPCEndpoint("evm_increaseTime"), params=n_seconds + method=web3_types.RPCEndpoint("evm_increaseTime"), params=[n_seconds] ) self.w3.provider.make_request(method=web3_types.RPCEndpoint("evm_mine"), params=[]) @@ -61,6 +61,12 @@ def get_block_by_hash(self, block_hash: str) -> "EthBlock": def get_balance(self, address: str) -> int: return self.w3.eth.getBalance(address) + def set_account_balance(self, address: str, balance: int): + assert balance >= 0, "Cannot set a negative balance." + self.w3.provider.make_request( + method=web3_types.RPCEndpoint("evm_setAccountBalance"), params=[address, balance] + ) + class Ganache: """ @@ -74,8 +80,8 @@ def __init__(self): """ self.port = random.randrange(1024, 8192) self.ganache_proc = subprocess.Popen( - f"ganache-cli -p {self.port} --chainId 32 --networkId 32 --gasLimit 8000000 " - "--allow-unlimited-contract-size", + f"ganache -p {self.port} --chain.chainId 32 --chain.networkId 32 " + "--miner.blockGasLimit 8000000 --chain.allow-unlimited-contract-size", shell=True, stdout=subprocess.DEVNULL, # Open the process in a new process group. @@ -226,7 +232,7 @@ def transact(self, *args, transact_args: Optional[Dict[str, Any]] = None) -> "Et tx_hash = self._func(*args).transact(transact_args) w3_tx_receipt = self.contract.w3.eth.wait_for_transaction_receipt(tx_hash) return EthReceipt(contract=self.contract, w3_tx_receipt=w3_tx_receipt) - except web3.exceptions.ContractLogicError as ex: + except (web3.exceptions.ContractLogicError, ValueError) as ex: raise EthRevertException(str(ex)) from None def call(self, *args, transact_args=None): @@ -236,7 +242,7 @@ def call(self, *args, transact_args=None): args = fix_tx_args(args) try: return handle_w3_value(self._func(*args).call(transact_args)) - except web3.exceptions.ContractLogicError as ex: + except (web3.exceptions.ContractLogicError, ValueError) as ex: raise EthRevertException(str(ex)) from None def __call__(self, *args, transact_args=None): diff --git a/src/starkware/python/async_subprocess.py b/src/starkware/python/async_subprocess.py index 855bd7e2..194ccef5 100644 --- a/src/starkware/python/async_subprocess.py +++ b/src/starkware/python/async_subprocess.py @@ -1,11 +1,34 @@ import asyncio import sys -from typing import List, Union +from typing import List, Optional, Tuple, Union -async def async_check_output(args: Union[str, List[str]], shell: bool = False, cwd=None, env=None): +async def async_check_output( + args: Union[str, List[str]], shell: bool = False, cwd=None, env=None +) -> str: """ An async equivalent to subprocess.check_output(). + Note that this function returns a string (ascii decoded). + """ + decoded_stdout, decoded_stderr, returncode = await async_run_command( + args=args, shell=shell, cwd=cwd, env=env + ) + print(decoded_stderr, file=sys.stderr) + assert ( + returncode == 0 + ), f"""\ +stderr: {decoded_stderr} +stdout: {decoded_stdout} +""" + return decoded_stdout + + +async def async_run_command( + args: Union[str, List[str]], shell: bool = False, cwd=None, env=None +) -> Tuple[str, str, Optional[int]]: + """ + Runs a command. Returns the outputs - regular and error - as strings, + and the exit code of the command. """ if shell: assert isinstance(args, str), "args must be a string where shell=True." @@ -19,12 +42,6 @@ async def async_check_output(args: Union[str, List[str]], shell: bool = False, c stderr=asyncio.subprocess.PIPE, ) stdout_data, stderr_data = await proc.communicate() - decoded_stderr = stderr_data.decode() - print(decoded_stderr, file=sys.stderr) - assert ( - proc.returncode == 0 - ), f"""\ -stderr: {decoded_stderr} -stdout: {stdout_data.decode()} -""" - return stdout_data + decoded_stderr = stderr_data.decode("ascii").strip() + decoded_stdout = stdout_data.decode("ascii").strip() + return decoded_stdout, decoded_stderr, proc.returncode diff --git a/src/starkware/python/utils.py b/src/starkware/python/utils.py index d2346abb..8a83c8e1 100644 --- a/src/starkware/python/utils.py +++ b/src/starkware/python/utils.py @@ -12,8 +12,10 @@ from collections import UserDict from typing import ( Any, + AsyncContextManager, AsyncGenerator, AsyncIterable, + AsyncIterator, Awaitable, Callable, Coroutine, @@ -25,6 +27,7 @@ Mapping, Optional, Sequence, + Tuple, TypeVar, ) @@ -35,8 +38,7 @@ from starkware.python.utils_stub_module import * # noqa T = TypeVar("T") -TYield = TypeVar("TYield") -TSend = TypeVar("TSend") +TAsyncGenerator = TypeVar("TAsyncGenerator", bound=AsyncGenerator) NumType = TypeVar("NumType", int, float) HASH_BYTES = 32 @@ -538,7 +540,7 @@ def execute_coroutine_threadsafe( return future.result() -class aclosing(contextlib.AbstractAsyncContextManager, Generic[TYield, TSend]): +class aclosing(contextlib.AbstractAsyncContextManager, Generic[TAsyncGenerator]): """ Async context manager for safely finalizing an asynchronously cleaned-up resource such as an async generator, calling its 'aclose()' method. @@ -549,11 +551,33 @@ class aclosing(contextlib.AbstractAsyncContextManager, Generic[TYield, TSend]): See https://peps.python.org/pep-0533/ for more info. """ - def __init__(self, agen: AsyncGenerator[TYield, TSend]): + def __init__(self, agen: TAsyncGenerator): self.agen = agen - async def __aenter__(self) -> AsyncGenerator[TYield, TSend]: + async def __aenter__(self) -> TAsyncGenerator: return self.agen async def __aexit__(self, *exc_info): await self.agen.aclose() + + +def aclosing_context_manager( + function: Callable[..., TAsyncGenerator] +) -> Callable[..., AsyncContextManager[TAsyncGenerator]]: + """ + Wraps a function that returns an async generator with aclosing context manager. + """ + + def wrapper(*args, **kwargs): + return aclosing(agen=function(*args, **kwargs)) + + return wrapper + + +async def aenumerate(aiterable: AsyncIterable[T], start: int = 0) -> AsyncIterator[Tuple[int, T]]: + """ + Asynchronously enumerates an async iterable from a given start value. + """ + counter = itertools.count(start) + async for element in aiterable: + yield next(counter), element diff --git a/src/starkware/python/utils_test.py b/src/starkware/python/utils_test.py index d5a405f3..692fb8bb 100644 --- a/src/starkware/python/utils_test.py +++ b/src/starkware/python/utils_test.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import functools import random import re @@ -9,7 +10,7 @@ from starkware.python.utils import ( WriteOnceDict, - aclosing, + aclosing_context_manager, all_subclasses, as_non_optional, assert_exhausted, @@ -235,28 +236,36 @@ def sync_foo(x: int) -> int: @pytest.mark.asyncio async def test_aclosing(): - closed: bool + @dataclasses.dataclass + class IsClosed: + value: bool = False - async def gen_foo(): - nonlocal closed - closed = False + # Break an async loop before fully exhausting the generator, under the context manager. + @aclosing_context_manager + async def wrapped_foo_gen(is_closed: IsClosed): try: - for i in range(5): - yield i + yield finally: - closed = True + is_closed.value = True - # Break an async loop before fully exhausting the generator, under the context manager. - async with aclosing(gen_foo()) as gen: + is_closed = IsClosed() + async with wrapped_foo_gen(is_closed=is_closed) as gen: async for _ in gen: break - assert closed + assert is_closed.value # Same, but without the context manager - the generator is expected to be alive after the break. - gen = gen_foo() + async def foo_gen(is_closed: IsClosed): + try: + yield + finally: + is_closed.value = True + + is_closed = IsClosed() + gen = foo_gen(is_closed=is_closed) async for _ in gen: break - assert not closed + assert not is_closed.value await gen.aclose() # Close properly. diff --git a/src/starkware/starknet/business_logic/execution/CMakeLists.txt b/src/starkware/starknet/business_logic/execution/CMakeLists.txt index 4df9db40..304274cb 100644 --- a/src/starkware/starknet/business_logic/execution/CMakeLists.txt +++ b/src/starkware/starknet/business_logic/execution/CMakeLists.txt @@ -7,7 +7,7 @@ python_lib(starknet_transaction_execution_objects_lib LIBS cairo_vm_lib everest_definitions_lib - everest_internal_transaction_lib + everest_transaction_execution_objects_lib starknet_abi_lib starknet_business_logic_patricia_state_lib starknet_business_logic_state_lib diff --git a/src/starkware/starknet/business_logic/execution/execute_entry_point.py b/src/starkware/starknet/business_logic/execution/execute_entry_point.py index 5272e6da..62d42901 100644 --- a/src/starkware/starknet/business_logic/execution/execute_entry_point.py +++ b/src/starkware/starknet/business_logic/execution/execute_entry_point.py @@ -36,6 +36,7 @@ EntryPointType, ) from starkware.starkware_utils.error_handling import ( + ErrorCode, StarkException, stark_assert, wrap_with_stark_exception, @@ -225,7 +226,7 @@ def _run( verify_secure=True, ) except VmException as exception: - code = StarknetErrorCode.TRANSACTION_FAILED + code: ErrorCode = StarknetErrorCode.TRANSACTION_FAILED if isinstance(exception.inner_exc, HintException): hint_exception = exception.inner_exc diff --git a/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py b/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py index 1b90e441..aa5535ee 100644 --- a/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py +++ b/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py @@ -1,7 +1,7 @@ import dataclasses from abc import ABC, abstractmethod from dataclasses import field -from typing import List, Optional, TypeVar +from typing import List, Optional from starkware.starknet.business_logic.execution.objects import ( CallInfo, @@ -15,8 +15,6 @@ from starkware.starknet.services.api.contract_class import EntryPointType from starkware.starkware_utils.validated_dataclass import ValidatedDataclass -TExecuteEntryPoint = TypeVar("TExecuteEntryPoint", bound="ExecuteEntryPointBase") - # Mypy has a problem with dataclasses that contain unimplemented abstract methods. # See https://github.com/python/mypy/issues/5374 for details on this problem. diff --git a/src/starkware/starknet/business_logic/execution/gas_usage.py b/src/starkware/starknet/business_logic/execution/gas_usage.py index a7c76aef..70d6ac15 100644 --- a/src/starkware/starknet/business_logic/execution/gas_usage.py +++ b/src/starkware/starknet/business_logic/execution/gas_usage.py @@ -8,7 +8,7 @@ def calculate_tx_gas_usage( l2_to_l1_messages: List[L2ToL1MessageInfo], n_modified_contracts: int, - n_storage_writes: int, + n_storage_changes: int, l1_handler_payload_size: Optional[int], n_deployments: int, ) -> int: @@ -31,9 +31,9 @@ def calculate_tx_gas_usage( ) # Calculate the effect of the transaction on the output data availability segment. - residual_da_segment_length = get_da_segment_length( + residual_onchain_data_segment_length = get_onchain_data_segment_length( n_modified_contracts=n_modified_contracts, - n_storage_writes=n_storage_writes, + n_storage_changes=n_storage_changes, n_deployments=n_deployments, ) @@ -55,7 +55,7 @@ def calculate_tx_gas_usage( sharp_gas_usage = ( residual_message_segment_length * eth_gas_constants.SHARP_GAS_PER_MEMORY_WORD - + residual_da_segment_length * eth_gas_constants.SHARP_GAS_PER_MEMORY_WORD + + residual_onchain_data_segment_length * eth_gas_constants.SHARP_GAS_PER_MEMORY_WORD ) return starknet_gas_usage + sharp_gas_usage @@ -86,9 +86,9 @@ def get_message_segment_length( return message_segment_length -def get_da_segment_length( +def get_onchain_data_segment_length( n_modified_contracts: int, - n_storage_writes: int, + n_storage_changes: int, n_deployments: int, ) -> int: """ @@ -99,13 +99,13 @@ def get_da_segment_length( storage updates. """ # For each newly modified contract: contract address, number of modified storage cells. - da_segment_length = n_modified_contracts * 2 + onchain_data_segment_length = n_modified_contracts * 2 # For each modified storage cell: key, new value. - da_segment_length += n_storage_writes * 2 + onchain_data_segment_length += n_storage_changes * 2 # Add size of deployment info. - da_segment_length += n_deployments * constants.DEPLOYMENT_INFO_SIZE + onchain_data_segment_length += n_deployments * constants.DEPLOYMENT_INFO_SIZE - return da_segment_length + return onchain_data_segment_length def get_consumed_message_to_l2_emissions_cost(l1_handler_payload_size: Optional[int]) -> int: diff --git a/src/starkware/starknet/business_logic/execution/objects.py b/src/starkware/starknet/business_logic/execution/objects.py index f89297c2..72ba78fe 100644 --- a/src/starkware/starknet/business_logic/execution/objects.py +++ b/src/starkware/starknet/business_logic/execution/objects.py @@ -9,7 +9,9 @@ import marshmallow.fields as mfields import marshmallow_dataclass -from services.everest.business_logic.internal_transaction import EverestTransactionExecutionInfo +from services.everest.business_logic.transaction_execution_objects import ( + EverestTransactionExecutionInfo, +) from services.everest.definitions import fields as everest_fields from starkware.cairo.lang.vm.cairo_pie import ExecutionResources from starkware.cairo.lang.vm.utils import RunResources @@ -17,6 +19,7 @@ from starkware.starknet.business_logic.fact_state.contract_state_objects import StateSelector from starkware.starknet.business_logic.state.state import StorageEntry from starkware.starknet.definitions import constants, fields +from starkware.starknet.definitions.transaction_type import TransactionType from starkware.starknet.public.abi import CONSTRUCTOR_ENTRY_POINT_SELECTOR from starkware.starknet.services.api.contract_class import EntryPointType from starkware.starknet.services.api.gateway.transaction import DEFAULT_DECLARE_SENDER_ADDRESS @@ -391,10 +394,16 @@ class TransactionExecutionInfo(EverestTransactionExecutionInfo): # Actual resources the transaction is charged for, including L1 gas # and OS additional resources estimation. actual_resources: ResourcesMapping = field(metadata=fields.name_to_resources_metadata) + # Transaction type is used to determine the order of the calls. + tx_type: Optional[TransactionType] @property def non_optional_calls(self) -> Iterable[CallInfo]: - ordered_optional_calls = (self.validate_info, self.call_info, self.fee_transfer_info) + if self.tx_type is TransactionType.DEPLOY_ACCOUNT: + # In deploy account tx, validation will take place after execution of the constructor. + ordered_optional_calls = (self.call_info, self.validate_info, self.fee_transfer_info) + else: + ordered_optional_calls = (self.validate_info, self.call_info, self.fee_transfer_info) return tuple(call for call in ordered_optional_calls if call is not None) def get_state_selector(self) -> StateSelector: @@ -410,6 +419,7 @@ def get_visited_storage_entries(self) -> Set[StorageEntry]: def from_call_infos( cls, execute_call_info: Optional[CallInfo], + tx_type: Optional[TransactionType], validate_info: Optional[CallInfo] = None, fee_transfer_info: Optional[CallInfo] = None, ) -> "TransactionExecutionInfo": @@ -419,6 +429,59 @@ def from_call_infos( fee_transfer_info=fee_transfer_info, actual_fee=0, actual_resources={}, + tx_type=tx_type, + ) + + @classmethod + def empty(cls) -> "TransactionExecutionInfo": + return cls( + validate_info=None, + call_info=None, + fee_transfer_info=None, + actual_fee=0, + actual_resources={}, + tx_type=None, + ) + + @classmethod + def create_concurrent_stage_execution_info( + cls, + validate_info: Optional[CallInfo], + call_info: Optional[CallInfo], + actual_resources: ResourcesMapping, + tx_type: TransactionType, + ) -> "TransactionExecutionInfo": + """ + Returns TransactionExecutionInfo for the concurrent stage (without + fee_transfer_info and without fee). + """ + return cls( + validate_info=validate_info, + call_info=call_info, + fee_transfer_info=None, + actual_fee=0, + actual_resources=actual_resources, + tx_type=tx_type, + ) + + @classmethod + def from_concurrent_stage_execution_info( + cls, + concurrent_execution_info: "TransactionExecutionInfo", + actual_fee: int, + fee_transfer_info: Optional[CallInfo], + ) -> "TransactionExecutionInfo": + """ + Fills the given concurrent_execution_info with actual_fee and fee_transfer_info. + Used when the call infos (except for the fee handling) executed in the concurrent stage. + """ + return cls( + validate_info=concurrent_execution_info.validate_info, + call_info=concurrent_execution_info.call_info, + fee_transfer_info=fee_transfer_info, + actual_fee=actual_fee, + actual_resources=concurrent_execution_info.actual_resources, + tx_type=concurrent_execution_info.tx_type, ) def gen_call_iterator(self) -> Iterator[CallInfo]: diff --git a/src/starkware/starknet/business_logic/execution/os_resources.json b/src/starkware/starknet/business_logic/execution/os_resources.json index 078e13ab..6fa60a6c 100644 --- a/src/starkware/starknet/business_logic/execution/os_resources.json +++ b/src/starkware/starknet/business_logic/execution/os_resources.json @@ -27,7 +27,7 @@ "range_check_builtin": 16 }, "n_memory_holes": 0, - "n_steps": 796 + "n_steps": 798 }, "emit_event": { "builtin_instance_counter": {}, @@ -106,20 +106,28 @@ "range_check_builtin": 57 }, "n_memory_holes": 0, - "n_steps": 2330 + "n_steps": 2334 }, "DEPLOY": { "builtin_instance_counter": {}, "n_memory_holes": 0, "n_steps": 0 }, + "DEPLOY_ACCOUNT": { + "builtin_instance_counter": { + "pedersen_builtin": 23, + "range_check_builtin": 74 + }, + "n_memory_holes": 0, + "n_steps": 3096 + }, "INVOKE_FUNCTION": { "builtin_instance_counter": { "pedersen_builtin": 16, "range_check_builtin": 70 }, "n_memory_holes": 0, - "n_steps": 2833 + "n_steps": 2835 }, "L1_HANDLER": { "builtin_instance_counter": { diff --git a/src/starkware/starknet/business_logic/execution/os_usage.py b/src/starkware/starknet/business_logic/execution/os_usage.py index 92c722d2..cce3b2ef 100644 --- a/src/starkware/starknet/business_logic/execution/os_usage.py +++ b/src/starkware/starknet/business_logic/execution/os_usage.py @@ -42,6 +42,4 @@ def get_additional_os_resources( # Calculate the additional resources needed for the OS to run the given transaction; # i.e., the resources of the StarkNet OS function execute_transactions_inner(). - return os_additional_resources + os_resources.execute_txs_inner.get( - tx_type, ExecutionResources.empty() - ) + return os_additional_resources + os_resources.execute_txs_inner[tx_type] diff --git a/src/starkware/starknet/business_logic/fact_state/patricia_state.py b/src/starkware/starknet/business_logic/fact_state/patricia_state.py index 2450047c..2fb183af 100644 --- a/src/starkware/starknet/business_logic/fact_state/patricia_state.py +++ b/src/starkware/starknet/business_logic/fact_state/patricia_state.py @@ -4,13 +4,13 @@ ContractClassFact, ContractState, ) -from starkware.starknet.business_logic.state.state_api import StateReader -from starkware.starknet.definitions import fields -from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starknet.business_logic.state.state_api import ( + StateReader, + get_stark_exception_on_undeclared_contract, +) from starkware.starknet.services.api.contract_class import ContractClass from starkware.starknet.storage.starknet_storage import StorageLeaf from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree -from starkware.starkware_utils.error_handling import StarkException from starkware.storage.storage import FactFetchingContext @@ -30,7 +30,16 @@ def __init__(self, global_state_root: PatriciaTree, ffc: FactFetchingContext): # StateReader API. async def get_contract_class(self, class_hash: bytes) -> ContractClass: - return await self._fetch_contract_class(class_hash=class_hash) + contract_class_fact = await ContractClassFact.get( + storage=self.ffc.storage, suffix=class_hash + ) + + if contract_class_fact is None: + raise get_stark_exception_on_undeclared_contract(class_hash=class_hash) + + contract_class = contract_class_fact.contract_definition + contract_class.validate() + return contract_class async def get_class_hash_at(self, contract_address: int) -> bytes: contract_state = await self._get_contract_state(contract_address=contract_address) @@ -54,6 +63,16 @@ async def get_storage_at(self, contract_address: int, key: int) -> int: # Internal utilities. + async def _get_raw_contract_class(self, class_hash: bytes) -> bytes: + raw_contract_class_fact = await self.ffc.storage.get_value( + key=ContractClassFact.db_key(suffix=class_hash) + ) + + if raw_contract_class_fact is None: + raise get_stark_exception_on_undeclared_contract(class_hash=class_hash) + + return raw_contract_class_fact + async def _get_contract_state(self, contract_address: int) -> ContractState: if contract_address not in self.contract_states: self.contract_states[contract_address] = await self._fetch_contract_state( @@ -62,23 +81,6 @@ async def _get_contract_state(self, contract_address: int) -> ContractState: return self.contract_states[contract_address] - async def _fetch_contract_class(self, class_hash: bytes) -> ContractClass: - contract_class_fact = await ContractClassFact.get( - storage=self.ffc.storage, suffix=class_hash - ) - - if contract_class_fact is None: - formatted_class_hash = fields.class_hash_from_bytes(class_hash=class_hash) - raise StarkException( - code=StarknetErrorCode.UNDECLARED_CLASS, - message=f"Class with hash {formatted_class_hash} is not declared.", - ) - - contract_class = contract_class_fact.contract_definition - contract_class.validate() - - return contract_class - async def _fetch_contract_state(self, contract_address: int) -> ContractState: return await self.global_state_root.get_leaf( ffc=self.ffc, index=contract_address, fact_cls=ContractState diff --git a/src/starkware/starknet/business_logic/fact_state/state.py b/src/starkware/starknet/business_logic/fact_state/state.py index c4d1b386..d5cf6dc4 100644 --- a/src/starkware/starknet/business_logic/fact_state/state.py +++ b/src/starkware/starknet/business_logic/fact_state/state.py @@ -1,4 +1,3 @@ -import copy import logging from typing import Dict, Mapping, MutableMapping, Optional @@ -18,7 +17,6 @@ ) from starkware.starknet.business_logic.fact_state.patricia_state import PatriciaStateReader from starkware.starknet.business_logic.state.state import CachedState, StorageEntry -from starkware.starknet.business_logic.state.state_api import SyncState from starkware.starknet.business_logic.state.state_api_objects import BlockInfo from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.services.api.contract_class import ContractClass @@ -42,15 +40,11 @@ class ExecutionResourcesManager: def __init__( self, cairo_usage: ExecutionResources, - modified_contracts: Dict[int, None], syscall_counter: Dict[str, int], ): # The accumulated Cairo usage. self.cairo_usage = cairo_usage - # Addresses of contracts whose storage has changed. - self.modified_contracts = modified_contracts - # A mapping from system call to the cumulative times it was invoked. self.syscall_counter = syscall_counter @@ -60,7 +54,6 @@ def __init__( def empty(cls) -> "ExecutionResourcesManager": return cls( cairo_usage=ExecutionResources.empty(), - modified_contracts={}, syscall_counter={}, ) @@ -79,26 +72,14 @@ def __init__( self, parent_state: Optional["CarriedState"], state: CachedState, - resources_manager: ExecutionResourcesManager, ): """ Private constructor. Should only be called by _create_from_parent_state and create_unfilled class methods. """ super().__init__(parent_state=parent_state) - self.state = state - self._sync_state: Optional[SyncState] = None - - # The resources used throughout transaction stream processing. - self.resources_manager = resources_manager - - @property - def sync_state(self) -> SyncState: - assert self._sync_state is not None - return self._sync_state - # Alternative constructors. @classmethod @@ -115,10 +96,6 @@ def _create_from_parent_state(cls, parent_state: "CarriedState") -> "CarriedStat parent_state=parent_state, # Cached state. state=parent_state.state._copy(), - # Immutable objects. - # Chain maps - changes are inserted to the first map (at index 0); parent maps must not - # be modified. - resources_manager=copy.deepcopy(parent_state.resources_manager), ) @classmethod @@ -139,7 +116,6 @@ def create_unfilled( return cls( parent_state=None, state=state, - resources_manager=ExecutionResourcesManager.empty(), ) @classmethod @@ -188,7 +164,6 @@ def from_contracts( return cls( parent_state=None, state=state, - resources_manager=ExecutionResourcesManager.empty(), ) @property @@ -234,7 +209,6 @@ def _apply(self): """ # Apply state updates. parent_state = self.non_optional_parent_state - parent_state.resources_manager = self.resources_manager # Update CachedState. self.state._apply(parent=parent_state.state) @@ -372,6 +346,19 @@ class StateDiff(EverestStateDiff): storage_updates: Mapping[StorageEntry, int] block_info: BlockInfo + @classmethod + def empty(cls, block_info: BlockInfo): + """ + Returns an empty state diff object relative to the given block info. + """ + return cls( + class_hash_to_class={}, + address_to_class_hash={}, + address_to_nonce={}, + storage_updates={}, + block_info=block_info, + ) + @classmethod def from_cached_state(cls, cached_state: CachedState) -> "StateDiff": state_cache = cached_state.cache diff --git a/src/starkware/starknet/business_logic/state/state.py b/src/starkware/starknet/business_logic/state/state.py index 458facf3..0cd6822d 100644 --- a/src/starkware/starknet/business_logic/state/state.py +++ b/src/starkware/starknet/business_logic/state/state.py @@ -44,6 +44,12 @@ def get_contract_class(self, class_hash: bytes) -> ContractClass: coroutine=self.async_state.get_contract_class(class_hash=class_hash), loop=self.loop ) + def _get_raw_contract_class(self, class_hash: bytes) -> bytes: + return execute_coroutine_threadsafe( + coroutine=self.async_state._get_raw_contract_class(class_hash=class_hash), + loop=self.loop, + ) + def get_class_hash_at(self, contract_address: int) -> bytes: return execute_coroutine_threadsafe( coroutine=self.async_state.get_class_hash_at(contract_address=contract_address), @@ -100,11 +106,12 @@ class StateCache: def __init__(self): self.contract_classes: Dict[bytes, ContractClass] = {} + self.raw_contract_classes: Dict[bytes, bytes] = {} - # Reader's cached information. - self._class_hash_reads: Dict[int, bytes] = {} - self._nonce_reads: Dict[int, int] = {} - self._storage_reads: Dict[StorageEntry, int] = {} + # Reader's cached information; initial values, read before any write operation (per cell). + self._class_hash_initial_values: Dict[int, bytes] = {} + self._nonce_initial_values: Dict[int, int] = {} + self._storage_initial_values: Dict[StorageEntry, int] = {} # Writer's cached information. self._class_hash_writes: Dict[int, bytes] = {} @@ -115,12 +122,14 @@ def __init__(self): # Mappings from contract address to different attributes. self.address_to_class_hash: Mapping[int, bytes] = ChainMap( - self._class_hash_writes, self._class_hash_reads + self._class_hash_writes, self._class_hash_initial_values + ) + self.address_to_nonce: Mapping[int, int] = ChainMap( + self._nonce_writes, self._nonce_initial_values ) - self.address_to_nonce: Mapping[int, int] = ChainMap(self._nonce_writes, self._nonce_reads) # Mapping from (contract_address, key) to a value in the contract's storage. self.storage_view: Mapping[StorageEntry, int] = ChainMap( - self._storage_writes, self._storage_reads + self._storage_writes, self._storage_initial_values ) def update_writes_from_other(self, other: "StateCache"): @@ -173,29 +182,38 @@ async def get_contract_class(self, class_hash: bytes) -> ContractClass: return self.cache.contract_classes[class_hash] + async def _get_raw_contract_class(self, class_hash: bytes) -> bytes: + if class_hash not in self.cache.raw_contract_classes: + raw_contract_class = await self.state_reader._get_raw_contract_class( + class_hash=class_hash + ) + self.cache.raw_contract_classes[class_hash] = raw_contract_class + + return self.cache.raw_contract_classes[class_hash] + async def get_class_hash_at(self, contract_address: int) -> bytes: if contract_address not in self.cache.address_to_class_hash: class_hash = await self.state_reader.get_class_hash_at( contract_address=contract_address ) - self.cache._class_hash_reads[contract_address] = class_hash + self.cache._class_hash_initial_values[contract_address] = class_hash return self.cache.address_to_class_hash[contract_address] async def get_nonce_at(self, contract_address: int) -> int: if contract_address not in self.cache.address_to_nonce: - self.cache._nonce_reads[contract_address] = await self.state_reader.get_nonce_at( - contract_address=contract_address - ) + self.cache._nonce_initial_values[ + contract_address + ] = await self.state_reader.get_nonce_at(contract_address=contract_address) return self.cache.address_to_nonce[contract_address] async def get_storage_at(self, contract_address: int, key: int) -> int: address_key_pair = (contract_address, key) if address_key_pair not in self.cache.storage_view: - self.cache._storage_reads[address_key_pair] = await self.state_reader.get_storage_at( - contract_address=contract_address, key=key - ) + self.cache._storage_initial_values[ + address_key_pair + ] = await self.state_reader.get_storage_at(contract_address=contract_address, key=key) return self.cache.storage_view[address_key_pair] @@ -271,17 +289,25 @@ def get_contract_class(self, class_hash: bytes) -> ContractClass: return self.cache.contract_classes[class_hash] + def _get_raw_contract_class(self, class_hash: bytes) -> bytes: + if class_hash not in self.cache.raw_contract_classes: + self.cache.raw_contract_classes[class_hash] = self.state_reader._get_raw_contract_class( + class_hash=class_hash + ) + + return self.cache.raw_contract_classes[class_hash] + def get_class_hash_at(self, contract_address: int) -> bytes: if contract_address not in self.cache.address_to_class_hash: - self.cache._class_hash_reads[contract_address] = self.state_reader.get_class_hash_at( - contract_address=contract_address - ) + self.cache._class_hash_initial_values[ + contract_address + ] = self.state_reader.get_class_hash_at(contract_address=contract_address) return self.cache.address_to_class_hash[contract_address] def get_nonce_at(self, contract_address: int) -> int: if contract_address not in self.cache.address_to_nonce: - self.cache._nonce_reads[contract_address] = self.state_reader.get_nonce_at( + self.cache._nonce_initial_values[contract_address] = self.state_reader.get_nonce_at( contract_address=contract_address ) @@ -290,7 +316,7 @@ def get_nonce_at(self, contract_address: int) -> int: def get_storage_at(self, contract_address: int, key: int) -> int: address_key_pair = (contract_address, key) if address_key_pair not in self.cache.storage_view: - self.cache._storage_reads[address_key_pair] = self.state_reader.get_storage_at( + self.cache._storage_initial_values[address_key_pair] = self.state_reader.get_storage_at( contract_address=contract_address, key=key ) @@ -346,3 +372,100 @@ def read(self, address: int) -> int: def write(self, address: int, value: int): self.accessed_keys.add(address) self.state.set_storage_at(contract_address=self.contract_address, key=address, value=value) + + +class UpdatesTrackerState(SyncState): + """ + An implementation of the SyncState API that wraps another SyncState object and contains a cache. + All requests are delegated to the wrapped SyncState, and caches are maintained for storage reads + and writes. + + The goal of this implementation is to allow more precise and fair computation of the number of + storage-writes a single transaction preforms for the purposes of transaction fee calculation. + That is, if a given transaction writes to the same storage address multiple times, this should + be counted as a single storage-write. Additionally, if a transaction writes a value to storage + which is equal to the initial value previously contained in that address, then no change needs + to be done and this should not count as a storage-write. + """ + + def __init__(self, state: SyncState): + self.state = state + # Initial values read before any write operation (per storage cell). + self._storage_initial_values: Dict[StorageEntry, int] = {} + self._storage_writes: Dict[StorageEntry, int] = {} + + def get_storage_at(self, contract_address: int, key: int) -> int: + # Delegate the request to the actual state anyway (even if the value is already cached). + return_value = self.state.get_storage_at(contract_address=contract_address, key=key) + address_key_pair = (contract_address, key) + if not self._was_accessed(address_key_pair=address_key_pair): + # First access (read or write) to this cell; cache initial value. + self._storage_initial_values[address_key_pair] = return_value + + return return_value + + def set_storage_at(self, contract_address: int, key: int, value: int): + """ + This method writes to a storage cell and updates the cache accordingly. If this is the first + access to the cell (read or write), the method first reads the value at that cell and caches + it. + + This read operation is necessary for fee calculation. Because if the transaction writes a + value to storage that is identical to the value previously held at that address, then no + change is made to that cell and it does not count as a storage-change in fee calculation. + """ + address_key_pair = (contract_address, key) + if not self._was_accessed(address_key_pair=address_key_pair): + # First access (read or write) to this cell; cache initial value. + self._storage_initial_values[address_key_pair] = self.state.get_storage_at( + contract_address=contract_address, key=key + ) + + self._storage_writes[address_key_pair] = value + return self.state.set_storage_at(contract_address=contract_address, key=key, value=value) + + @property + def block_info(self) -> BlockInfo: + return self.state.block_info + + def update_block_info(self, block_info: BlockInfo): + return self.state.update_block_info(block_info=block_info) + + def _get_raw_contract_class(self, class_hash: bytes) -> bytes: + return self.state._get_raw_contract_class(class_hash=class_hash) + + def get_contract_class(self, class_hash: bytes) -> ContractClass: + return self.state.get_contract_class(class_hash=class_hash) + + def get_class_hash_at(self, contract_address: int) -> bytes: + return self.state.get_class_hash_at(contract_address=contract_address) + + def get_nonce_at(self, contract_address: int) -> int: + return self.state.get_nonce_at(contract_address=contract_address) + + def set_contract_class(self, class_hash: bytes, contract_class: ContractClass): + return self.state.set_contract_class(class_hash=class_hash, contract_class=contract_class) + + def deploy_contract(self, contract_address: int, class_hash: bytes): + return self.state.deploy_contract(contract_address=contract_address, class_hash=class_hash) + + def increment_nonce(self, contract_address: int): + return self.state.increment_nonce(contract_address=contract_address) + + def count_actual_storage_changes(self) -> Tuple[int, int]: + """ + Returns the number of storage changes done through this state, and the number of modified + contracts, where a contract is considered as modified if one or more of its storage cells + has changed. + """ + storage_updates = dict(self._storage_writes.items() - self._storage_initial_values.items()) + modified_contracts = { + contract_address for (contract_address, _key) in storage_updates.keys() + } + return (len(modified_contracts), len(storage_updates)) + + def _was_accessed(self, address_key_pair: Tuple[int, int]) -> bool: + return ( + address_key_pair in self._storage_initial_values + or address_key_pair in self._storage_writes + ) diff --git a/src/starkware/starknet/business_logic/state/state_api.py b/src/starkware/starknet/business_logic/state/state_api.py index 6f7b6143..34d3d813 100644 --- a/src/starkware/starknet/business_logic/state/state_api.py +++ b/src/starkware/starknet/business_logic/state/state_api.py @@ -2,7 +2,10 @@ from services.everest.business_logic.state_api import StateProxy from starkware.starknet.business_logic.state.state_api_objects import BlockInfo +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.services.api.contract_class import ContractClass +from starkware.starkware_utils.error_handling import StarkException class StateReader(ABC): @@ -17,6 +20,13 @@ async def get_contract_class(self, class_hash: bytes) -> ContractClass: Raises an exception if said class was not declared. """ + @abstractmethod + async def _get_raw_contract_class(self, class_hash: bytes) -> bytes: + """ + Returns the raw bytes of the contract class object of the given class hash. + Raises an exception if said class was not declared. + """ + @abstractmethod async def get_class_hash_at(self, contract_address: int) -> bytes: """ @@ -89,6 +99,10 @@ class SyncStateReader(ABC): def get_contract_class(self, class_hash: bytes) -> ContractClass: pass + @abstractmethod + def _get_raw_contract_class(self, class_hash: bytes) -> bytes: + pass + @abstractmethod def get_class_hash_at(self, contract_address: int) -> bytes: pass @@ -133,3 +147,14 @@ def update_block_info(self, block_info: BlockInfo): @abstractmethod def set_storage_at(self, contract_address: int, key: int, value: int): pass + + +# Utilities. + + +def get_stark_exception_on_undeclared_contract(class_hash: bytes) -> StarkException: + formatted_class_hash = fields.class_hash_from_bytes(class_hash=class_hash) + return StarkException( + code=StarknetErrorCode.UNDECLARED_CLASS, + message=f"Class with hash {formatted_class_hash} is not declared.", + ) diff --git a/src/starkware/starknet/business_logic/state/state_api_objects.py b/src/starkware/starknet/business_logic/state/state_api_objects.py index 55dc1d1b..99eae28f 100644 --- a/src/starkware/starknet/business_logic/state/state_api_objects.py +++ b/src/starkware/starknet/business_logic/state/state_api_objects.py @@ -28,7 +28,7 @@ class BlockInfo(ValidatedMarshmallowDataclass): # The sequencer address of this block. sequencer_address: Optional[int] = field(metadata=fields.optional_sequencer_address_metadata) - # The version of StarkNet system (e.g. "0.10.0"). + # The version of StarkNet system (e.g. "0.10.1"). starknet_version: Optional[str] = field(metadata=fields.starknet_version_metadata) @classmethod diff --git a/src/starkware/starknet/business_logic/transaction/objects.py b/src/starkware/starknet/business_logic/transaction/objects.py index ea712ad0..b8c84aab 100644 --- a/src/starkware/starknet/business_logic/transaction/objects.py +++ b/src/starkware/starknet/business_logic/transaction/objects.py @@ -21,6 +21,7 @@ ) from starkware.starknet.business_logic.fact_state.contract_state_objects import StateSelector from starkware.starknet.business_logic.fact_state.state import ExecutionResourcesManager +from starkware.starknet.business_logic.state.state import UpdatesTrackerState from starkware.starknet.business_logic.state.state_api import SyncState from starkware.starknet.business_logic.state.state_api_objects import BlockInfo from starkware.starknet.business_logic.transaction.fee import calculate_tx_fee, execute_fee_transfer @@ -39,6 +40,7 @@ from starkware.starknet.core.os.transaction_hash.transaction_hash import ( TransactionHashPrefix, calculate_declare_transaction_hash, + calculate_deploy_account_transaction_hash, calculate_deploy_transaction_hash, calculate_transaction_hash_common, ) @@ -52,6 +54,7 @@ DEFAULT_DECLARE_SENDER_ADDRESS, Declare, Deploy, + DeployAccount, InvokeFunction, Transaction, ) @@ -139,7 +142,7 @@ def from_external( assert isinstance(external_tx, Transaction) assert isinstance(general_config, StarknetGeneralConfig) - internal_cls = InternalTransaction.external_to_internal_cls.get(type(external_tx)) + internal_cls = cls.external_to_internal_cls.get(type(external_tx)) if internal_cls is None: raise NotImplementedError(f"Unsupported transaction type {type(external_tx).__name__}.") @@ -174,6 +177,15 @@ async def apply_state_updates( assert isinstance(tx_execution_info, TransactionExecutionInfo) return tx_execution_info + @abstractmethod + def _apply_specific_sequential_changes( + self, + state: SyncState, + general_config: StarknetGeneralConfig, + concurrent_execution_info: TransactionExecutionInfo, + ) -> TransactionExecutionInfo: + pass + class SyntheticTransaction(InternalStateTransaction): """ @@ -203,8 +215,11 @@ class InitializeBlockInfo(SyntheticTransaction): block_info: BlockInfo tx_type: ClassVar[TransactionType] = TransactionType.INITIALIZE_BLOCK_INFO - def _apply_specific_state_updates( - self, state: SyncState, general_config: StarknetGeneralConfig + def _apply_specific_sequential_changes( + self, + state: SyncState, + general_config: StarknetGeneralConfig, + concurrent_execution_info: TransactionExecutionInfo, ) -> Optional[TransactionExecutionInfo]: # Validate progress is legal. state.block_info.validate_legal_progress(next_block_info=self.block_info) @@ -214,6 +229,11 @@ def _apply_specific_state_updates( return None + def _apply_specific_concurrent_changes( + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig + ) -> TransactionExecutionInfo: + return TransactionExecutionInfo.empty() + def get_state_selector(self, general_config: Config) -> StateSelector: return StateSelector.empty() @@ -228,7 +248,7 @@ class InternalAccountTransaction(InternalTransaction): # signed by the account contract. # This field allows invalidating old transactions, whenever the meaning of the other # transaction fields is changed (in the OS). - version: int = field(metadata=fields.tx_version_metadata) + version: int = field(metadata=fields.non_required_tx_version_metadata) # The maximal fee to be paid in Wei for the execution. max_fee: int = field(metadata=fields.fee_metadata) signature: List[int] = field(metadata=fields.signature_metadata) @@ -261,7 +281,7 @@ def validate_entry_point_selector(cls) -> int: """ def verify_version(self): - verify_version(version=self.version, only_query=False) + verify_version(version=self.version, only_query=False, old_supported_versions=[0]) def run_validate_entrypoint( self, @@ -291,7 +311,7 @@ def run_validate_entrypoint( n_steps=general_config.validate_max_n_steps ), ) - verify_no_calls_to_other_contracts(call_info=call_info) + verify_no_calls_to_other_contracts(call_info=call_info, function_name="'validate'") return call_info @@ -351,6 +371,27 @@ def _handle_nonce(self, state: SyncState): # transactions. state.increment_nonce(contract_address=self.account_contract_address) + def _apply_specific_sequential_changes( + self, + state: SyncState, + general_config: StarknetGeneralConfig, + concurrent_execution_info: TransactionExecutionInfo, + ) -> TransactionExecutionInfo: + self._handle_nonce(state=state) + + # Handle fee. + fee_transfer_info, actual_fee = self.charge_fee( + state=state, + general_config=general_config, + resources=concurrent_execution_info.actual_resources, + ) + + return TransactionExecutionInfo.from_concurrent_stage_execution_info( + concurrent_execution_info=concurrent_execution_info, + fee_transfer_info=fee_transfer_info, + actual_fee=actual_fee, + ) + @marshmallow_dataclass.dataclass(frozen=True) class InternalDeclare(InternalAccountTransaction): @@ -502,15 +543,14 @@ def get_state_selector(self, general_config: Config) -> StateSelector: class_hashes=[self.class_hash], ) - def _apply_specific_state_updates( - self, state: SyncState, general_config: StarknetGeneralConfig + def _apply_specific_concurrent_changes( + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig ) -> TransactionExecutionInfo: # Reject unsupported versions. This is necessary (in addition to the gateway's check) # since an old transaction might still reach here, e.g., in case of a re-org. self.verify_version() # Validate transaction. - self._handle_nonce(state=state) resources_manager = ExecutionResourcesManager.empty() validate_info = self.run_validate_entrypoint( state=state, @@ -520,20 +560,249 @@ def _apply_specific_state_updates( # Handle fee. actual_resources = calculate_tx_resources( - resources_manager=resources_manager, call_infos=[validate_info], tx_type=self.tx_type - ) - fee_transfer_info, actual_fee = self.charge_fee( - state=state, general_config=general_config, resources=actual_resources + state=state, + resources_manager=resources_manager, + call_infos=[validate_info], + tx_type=self.tx_type, ) - return TransactionExecutionInfo( + return TransactionExecutionInfo.create_concurrent_stage_execution_info( validate_info=validate_info, call_info=None, - fee_transfer_info=fee_transfer_info, - actual_fee=actual_fee, actual_resources=actual_resources, + tx_type=self.tx_type, + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class InternalDeployAccount(InternalAccountTransaction): + """ + Internal version of the DeployAccount transaction (deployment of StarkNet account contracts). + """ + + contract_address: int = field(metadata=fields.contract_address_metadata) + contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) + class_hash: bytes = field(metadata=fields.class_hash_metadata) + constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + version: int = field(metadata=fields.tx_version_metadata) + # Repeat `nonce` to narrow its type to non-optional int. + nonce: int = field(metadata=fields.nonce_metadata) + + # Class variables. + tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY_ACCOUNT + related_external_cls: ClassVar[Type[Transaction]] = DeployAccount + validate_entry_point_selector: ClassVar[int] = starknet_abi.VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR + + @property + def account_contract_address(self) -> int: + return self.contract_address + + @property + def validate_entrypoint_calldata(self) -> List[int]: + # '__validate_deploy__' is expected to get the arguments: + # class_hash, salt, constructor_calldata. + return [ + from_bytes(self.class_hash), + self.contract_address_salt, + *self.constructor_calldata, + ] + + def verify_version(self): + verify_version(version=self.version, only_query=False, old_supported_versions=[]) + + @classmethod + def create( + cls, + class_hash: int, + max_fee: int, + version: int, + nonce: int, + constructor_calldata: List[int], + signature: List[int], + contract_address_salt: int, + chain_id: int, + ) -> "InternalDeployAccount": + contract_address = calculate_contract_address_from_hash( + salt=contract_address_salt, + class_hash=class_hash, + constructor_calldata=constructor_calldata, + deployer_address=0, + ) + internal_deploy_account = cls( + contract_address=contract_address, + contract_address_salt=contract_address_salt, + constructor_calldata=constructor_calldata, + class_hash=to_bytes(class_hash), + version=version, + max_fee=max_fee, + signature=signature, + nonce=nonce, + hash_value=calculate_deploy_account_transaction_hash( + version=version, + contract_address=contract_address, + class_hash=class_hash, + constructor_calldata=constructor_calldata, + max_fee=max_fee, + nonce=nonce, + salt=contract_address_salt, + chain_id=chain_id, + ), + ) + internal_deploy_account.verify_version() + return internal_deploy_account + + @classmethod + async def create_for_testing( + cls, + contract_class: ContractClass, + max_fee: int, + contract_address_salt: int = 0, + constructor_calldata: Optional[List[int]] = None, + chain_id: int = 0, + signature: Optional[List[int]] = None, + ) -> "InternalDeployAccount": + return InternalDeployAccount.create( + class_hash=compute_class_hash(contract_class=contract_class), + contract_address_salt=contract_address_salt, + constructor_calldata=[] if constructor_calldata is None else constructor_calldata, + nonce=0, + max_fee=max_fee, + version=constants.TRANSACTION_VERSION, + signature=[] if signature is None else signature, + chain_id=chain_id, + ) + + @classmethod + def _specific_from_external( + cls, external_tx: Transaction, general_config: StarknetGeneralConfig + ) -> "InternalDeployAccount": + assert isinstance(external_tx, DeployAccount) + return cls.create( + class_hash=external_tx.class_hash, + max_fee=external_tx.max_fee, + version=external_tx.version, + nonce=external_tx.nonce, + constructor_calldata=external_tx.constructor_calldata, + signature=external_tx.signature, + contract_address_salt=external_tx.contract_address_salt, + chain_id=general_config.chain_id.value, + ) + + def to_external(self) -> DeployAccount: + return DeployAccount( + version=self.version, + max_fee=self.max_fee, + signature=self.signature, + nonce=self.nonce, + contract_address_salt=self.contract_address_salt, + class_hash=from_bytes(self.class_hash), + constructor_calldata=self.constructor_calldata, + ) + + def get_state_selector(self, general_config: Config) -> StateSelector: + """ + Returns the state selector of the transaction (i.e., subset of state commitment tree leaves + it affects). + """ + return StateSelector.create( + contract_addresses=[self.contract_address], class_hashes=[self.class_hash] + ) + + def _apply_specific_concurrent_changes( + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig + ) -> TransactionExecutionInfo: + """ + Adds the deployed contract to the global commitment tree state. + """ + # Reject unsupported versions. This is necessary (in addition to the gateway's check) + # since an old transaction might still reach here, e.g., in case of a re-org. + self.verify_version() + + # Ensure the class is declared (by reading it). + contract_class = state.get_contract_class(class_hash=self.class_hash) + + # Deploy. + state.deploy_contract(contract_address=self.contract_address, class_hash=self.class_hash) + + # Run the constructor. + resources_manager = ExecutionResourcesManager.empty() + constructor_call_info = self.handle_constructor( + contract_class=contract_class, + state=state, + general_config=general_config, + resources_manager=resources_manager, + ) + + # Validate transaction. + validate_info = self.run_validate_entrypoint( + state=state, resources_manager=resources_manager, general_config=general_config + ) + + actual_resources = calculate_tx_resources( + state=state, + resources_manager=resources_manager, + call_infos=[constructor_call_info, validate_info], + tx_type=self.tx_type, + ) + + return TransactionExecutionInfo.create_concurrent_stage_execution_info( + validate_info=validate_info, + call_info=constructor_call_info, + actual_resources=actual_resources, + tx_type=self.tx_type, + ) + + def handle_constructor( + self, + contract_class: ContractClass, + state: UpdatesTrackerState, + general_config: StarknetGeneralConfig, + resources_manager: ExecutionResourcesManager, + ) -> CallInfo: + n_ctors = len(contract_class.entry_points_by_type[EntryPointType.CONSTRUCTOR]) + if n_ctors == 0: + stark_assert( + len(self.constructor_calldata) == 0, + code=StarknetErrorCode.TRANSACTION_FAILED, + message="Cannot pass calldata to a contract with no constructor.", + ) + return CallInfo.empty_constructor_call( + contract_address=self.contract_address, + caller_address=0, + class_hash=self.class_hash, + ) + else: + return self.run_constructor_entrypoint( + state=state, general_config=general_config, resources_manager=resources_manager + ) + + def run_constructor_entrypoint( + self, + state: UpdatesTrackerState, + general_config: StarknetGeneralConfig, + resources_manager: ExecutionResourcesManager, + ) -> CallInfo: + call = ExecuteEntryPoint.create( + contract_address=self.contract_address, + entry_point_selector=starknet_abi.CONSTRUCTOR_ENTRY_POINT_SELECTOR, + entry_point_type=EntryPointType.CONSTRUCTOR, + calldata=self.constructor_calldata, + caller_address=0, + ) + constructor_call_info = call.execute( + state=state, + resources_manager=resources_manager, + general_config=general_config, + tx_execution_context=self.get_execution_context( + n_steps=general_config.validate_max_n_steps + ), + ) + verify_no_calls_to_other_contracts( + call_info=constructor_call_info, function_name="DeployAccount's constructor" ) + return constructor_call_info + @marshmallow_dataclass.dataclass(frozen=True) class InternalDeploy(InternalTransaction): @@ -546,7 +815,7 @@ class InternalDeploy(InternalTransaction): # signed by the account contract. # This field allows invalidating old transactions, whenever the meaning of the other # transaction fields is changed (in the OS). - version: int = field(metadata=fields.tx_version_metadata) + version: int = field(metadata=fields.non_required_tx_version_metadata) contract_address: int = field(metadata=fields.contract_address_metadata) contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) contract_hash: bytes = field(metadata=fields.non_required_class_hash_metadata) @@ -580,7 +849,7 @@ def create( chain_id: int, version: int, ): - verify_version(version=version, only_query=False) + verify_version(version=version, only_query=False, old_supported_versions=[0]) class_hash = compute_class_hash(contract_class=contract_class) contract_address = calculate_contract_address_from_hash( @@ -638,6 +907,10 @@ def _specific_from_external( version=external_tx.version, ) + @property + def class_hash(self) -> bytes: + return self.contract_hash + def to_external(self) -> Deploy: raise NotImplementedError("Cannot convert internal deploy transaction to external object.") @@ -648,26 +921,34 @@ def get_state_selector(self, general_config: Config) -> StateSelector: """ return StateSelector.create(contract_addresses=[self.contract_address], class_hashes=[]) - def _apply_specific_state_updates( - self, state: SyncState, general_config: StarknetGeneralConfig + def _apply_specific_concurrent_changes( + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig ) -> TransactionExecutionInfo: """ Adds the deployed contract to the global commitment tree state. """ # Reject unsupported versions. This is necessary (in addition to the gateway's check) # since an old transaction might still reach here, e.g., in case of a re-org. - verify_version(version=self.version, only_query=False) + verify_version(version=self.version, only_query=False, old_supported_versions=[0]) # Execute transaction. state.deploy_contract(contract_address=self.contract_address, class_hash=self.contract_hash) contract_class = state.get_contract_class(class_hash=self.contract_hash) n_ctors = len(contract_class.entry_points_by_type[EntryPointType.CONSTRUCTOR]) if n_ctors == 0: - return self.handle_empty_constructor() + return self.handle_empty_constructor(state=state) else: return self.invoke_constructor(state=state, general_config=general_config) - def handle_empty_constructor(self) -> TransactionExecutionInfo: + def _apply_specific_sequential_changes( + self, + state: SyncState, + general_config: StarknetGeneralConfig, + concurrent_execution_info: TransactionExecutionInfo, + ) -> TransactionExecutionInfo: + return concurrent_execution_info + + def handle_empty_constructor(self, state: UpdatesTrackerState) -> TransactionExecutionInfo: stark_assert( len(self.constructor_calldata) == 0, code=StarknetErrorCode.TRANSACTION_FAILED, @@ -681,19 +962,21 @@ def handle_empty_constructor(self) -> TransactionExecutionInfo: ) resources_manager = ExecutionResourcesManager.empty() actual_resources = calculate_tx_resources( - resources_manager=resources_manager, call_infos=[call_info], tx_type=self.tx_type + state=state, + resources_manager=resources_manager, + call_infos=[call_info], + tx_type=self.tx_type, ) - return TransactionExecutionInfo( + return TransactionExecutionInfo.create_concurrent_stage_execution_info( validate_info=None, call_info=call_info, - fee_transfer_info=None, - actual_fee=0, actual_resources=actual_resources, + tx_type=self.tx_type, ) def invoke_constructor( - self, state: SyncState, general_config: StarknetGeneralConfig + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig ) -> TransactionExecutionInfo: call = ExecuteEntryPoint.create( contract_address=self.contract_address, @@ -720,15 +1003,17 @@ def invoke_constructor( tx_execution_context=tx_execution_context, ) actual_resources = calculate_tx_resources( - resources_manager=resources_manager, call_infos=[call_info], tx_type=self.tx_type + state=state, + resources_manager=resources_manager, + call_infos=[call_info], + tx_type=self.tx_type, ) - return TransactionExecutionInfo( + return TransactionExecutionInfo.create_concurrent_stage_execution_info( validate_info=None, call_info=call_info, - fee_transfer_info=None, - actual_fee=0, actual_resources=actual_resources, + tx_type=self.tx_type, ) @@ -826,7 +1111,7 @@ def create_wrapped_with_account( ): """ Creates an account contract invocation to the 'dummy_account' - test contract at address 'account_address'. + test contract at address 'account_address'; should only be used in tests. """ return cls.create( @@ -949,8 +1234,8 @@ def get_state_selector(self, general_config: Config) -> StateSelector: return StateSelector.create(contract_addresses=contract_addresses, class_hashes=[]) - def _apply_specific_state_updates( - self, state: SyncState, general_config: StarknetGeneralConfig + def _apply_specific_concurrent_changes( + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig ) -> TransactionExecutionInfo: """ Applies self to 'state' by executing the entry point and charging fee for it (if needed). @@ -960,7 +1245,6 @@ def _apply_specific_state_updates( self.verify_version() # Validate transaction. - self._handle_nonce(state=state) resources_manager = ExecutionResourcesManager.empty() validate_info = self.run_validate_entrypoint( state=state, @@ -977,22 +1261,17 @@ def _apply_specific_state_updates( # Handle fee. actual_resources = calculate_tx_resources( + state=state, resources_manager=resources_manager, call_infos=[call_info, validate_info], tx_type=self.tx_type, ) - fee_transfer_info, actual_fee = self.charge_fee( - state=state, - general_config=general_config, - resources=actual_resources, - ) - return TransactionExecutionInfo( + return TransactionExecutionInfo.create_concurrent_stage_execution_info( validate_info=validate_info, call_info=call_info, - fee_transfer_info=fee_transfer_info, - actual_fee=actual_fee, actual_resources=actual_resources, + tx_type=self.tx_type, ) def run_validate_entrypoint( @@ -1051,7 +1330,7 @@ class InternalL1Handler(InternalTransaction): calldata: List[int] = field(metadata=fields.call_data_metadata) # A unique nonce, added by the StarkNet core contract on L1. Guarantees a unique # hash_value of transactions. - nonce: int = field(metadata=fields.nonce_metadata) + nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) # Class variables. tx_type: ClassVar[TransactionType] = TransactionType.L1_HANDLER @@ -1061,6 +1340,22 @@ class InternalL1Handler(InternalTransaction): def related_external_cls(cls) -> Type[Transaction]: raise NotImplementedError("InternalL1Handler does not have a corresponding external class.") + @marshmallow.decorators.pre_load + def remove_deprecated_fields( + self, data: Dict[str, Any], many: bool, **kwargs + ) -> Dict[str, Any]: + for deprecated_field in ( + "entry_point_type", + "max_fee", + "signature", + "version", + "caller_address", + "code_address", + ): + data.pop(deprecated_field, None) + + return data + @classmethod def _specific_from_external( cls, external_tx: Transaction, general_config: StarknetGeneralConfig @@ -1122,8 +1417,8 @@ def get_state_selector(self, general_config: Config) -> StateSelector: """ return StateSelector.create(contract_addresses=[self.contract_address], class_hashes=[]) - def _apply_specific_state_updates( - self, state: SyncState, general_config: StarknetGeneralConfig + def _apply_specific_concurrent_changes( + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig ) -> TransactionExecutionInfo: """ Applies self to 'state' by executing the L1-handler entry point. @@ -1146,28 +1441,37 @@ def _apply_specific_state_updates( n_steps=general_config.invoke_tx_max_n_steps ), ) + actual_resources = calculate_tx_resources( + state=state, resources_manager=resources_manager, call_infos=[call_info], tx_type=self.tx_type, l1_handler_payload_size=self.get_payload_size(), ) - return TransactionExecutionInfo( + return TransactionExecutionInfo.create_concurrent_stage_execution_info( validate_info=None, call_info=call_info, - fee_transfer_info=None, - actual_fee=0, actual_resources=actual_resources, + tx_type=self.tx_type, ) + def _apply_specific_sequential_changes( + self, + state: SyncState, + general_config: StarknetGeneralConfig, + concurrent_execution_info: TransactionExecutionInfo, + ) -> TransactionExecutionInfo: + return concurrent_execution_info + def get_execution_context(self, n_steps: int) -> TransactionExecutionContext: return TransactionExecutionContext.create( account_contract_address=self.contract_address, transaction_hash=self.hash_value, signature=[], max_fee=0, - nonce=self.nonce, + nonce=as_non_optional(self.nonce), n_steps=n_steps, version=constants.L1_HANDLER_VERSION, ) @@ -1195,6 +1499,7 @@ class InternalTransactionSchema(OneOfSchema): type_schemas: Dict[str, Type[marshmallow.Schema]] = { TransactionType.DECLARE.name: InternalDeclare.Schema, TransactionType.DEPLOY.name: InternalDeploy.Schema, + TransactionType.DEPLOY_ACCOUNT.name: InternalDeployAccount.Schema, TransactionType.INVOKE_FUNCTION.name: InternalInvokeFunction.Schema, TransactionType.L1_HANDLER.name: InternalL1Handler.Schema, } @@ -1202,6 +1507,17 @@ class InternalTransactionSchema(OneOfSchema): def get_obj_type(self, obj: InternalTransaction) -> str: return obj.tx_type.name + def get_data_type(self, data: Dict[str, Any]) -> str: + data_type = data.get(self.type_field) + if ( + data_type == TransactionType.INVOKE_FUNCTION.name + and data.get("entry_point_type") == TransactionType.L1_HANDLER.name + ): + data.pop(self.type_field) + return TransactionType.L1_HANDLER.name + + return super().get_data_type(data=data) + InternalTransaction.Schema = InternalTransactionSchema diff --git a/src/starkware/starknet/business_logic/transaction/state_objects.py b/src/starkware/starknet/business_logic/transaction/state_objects.py index d294a9a9..8f374426 100644 --- a/src/starkware/starknet/business_logic/transaction/state_objects.py +++ b/src/starkware/starknet/business_logic/transaction/state_objects.py @@ -1,24 +1,24 @@ import asyncio import functools import logging -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Iterable, Optional, cast from services.everest.business_logic.internal_transaction import EverestInternalStateTransaction from services.everest.business_logic.state_api import StateProxy from starkware.starknet.business_logic.execution.objects import TransactionExecutionInfo from starkware.starknet.business_logic.fact_state.contract_state_objects import StateSelector -from starkware.starknet.business_logic.state.state import StateSyncifier +from starkware.starknet.business_logic.state.state import StateSyncifier, UpdatesTrackerState from starkware.starknet.business_logic.state.state_api import State, SyncState from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starkware_utils.config_base import Config -from starkware.starkware_utils.error_handling import StarkException +from starkware.starkware_utils.error_handling import wrap_with_stark_exception logger = logging.getLogger(__name__) -class InternalStateTransaction(EverestInternalStateTransaction): +class InternalStateTransaction(EverestInternalStateTransaction, ABC): """ StarkNet internal state transaction. This is the API of transactions that update the state, @@ -72,25 +72,70 @@ def sync_apply_state_updates( assert isinstance(state, SyncState) assert isinstance(general_config, StarknetGeneralConfig) - try: - execution_info = self._apply_specific_state_updates( - state=state, general_config=general_config + concurrent_execution_info = self.apply_concurrent_changes( + state=state, general_config=general_config + ) + return self.apply_sequential_changes( + state=state, + general_config=general_config, + concurrent_execution_info=concurrent_execution_info, + ) + + def apply_concurrent_changes( + self, state: SyncState, general_config: StarknetGeneralConfig + ) -> TransactionExecutionInfo: + """ + Applies changes that can be efficiently done in concurrent execution. + Returns a partial execution info. + """ + with wrap_with_stark_exception( + code=StarknetErrorCode.UNEXPECTED_FAILURE, + logger=logger, + exception_types=[Exception], + ): + return self._apply_specific_concurrent_changes( + state=UpdatesTrackerState(state=state), general_config=general_config ) - except StarkException: - # Raise StarkException-s as-is, so failure information is not lost. - raise - except Exception as exception: - # Wrap all exceptions with StarkException, so the Batcher can continue running - # even after unexpected errors. - logger.error(f"Unexpected failure; exception details: {exception}.", exc_info=True) - raise StarkException( - code=StarknetErrorCode.UNEXPECTED_FAILURE, message=str(exception) - ) from exception - return execution_info + def apply_sequential_changes( + self, + state: SyncState, + general_config: StarknetGeneralConfig, + concurrent_execution_info: TransactionExecutionInfo, + ) -> Optional[TransactionExecutionInfo]: + """ + Applies the parts of the transaction needed to be executed sequentially to enable + efficient concurrency, as they are likely to collide in a concurrent execution, + for example, access to bottleneck storage cells such as the sequencer balance). + """ + with wrap_with_stark_exception( + code=StarknetErrorCode.UNEXPECTED_FAILURE, + logger=logger, + exception_types=[Exception], + ): + return self._apply_specific_sequential_changes( + state=state, + general_config=general_config, + concurrent_execution_info=concurrent_execution_info, + ) @abstractmethod - def _apply_specific_state_updates( - self, state: SyncState, general_config: StarknetGeneralConfig + def _apply_specific_concurrent_changes( + self, state: UpdatesTrackerState, general_config: StarknetGeneralConfig + ) -> TransactionExecutionInfo: + """ + A specific implementation of apply_concurrent_changes for each internal transaction. + See apply_concurrent_changes. + """ + + @abstractmethod + def _apply_specific_sequential_changes( + self, + state: SyncState, + general_config: StarknetGeneralConfig, + concurrent_execution_info: TransactionExecutionInfo, ) -> Optional[TransactionExecutionInfo]: - pass + """ + A specific implementation of apply_sequential_changes for each internal transaction. + See apply_sequential_changes. + """ diff --git a/src/starkware/starknet/business_logic/utils.py b/src/starkware/starknet/business_logic/utils.py index b4075495..f63bc9f4 100644 --- a/src/starkware/starknet/business_logic/utils.py +++ b/src/starkware/starknet/business_logic/utils.py @@ -14,6 +14,7 @@ from starkware.starknet.business_logic.execution.os_usage import get_additional_os_resources from starkware.starknet.business_logic.fact_state.contract_state_objects import ContractClassFact from starkware.starknet.business_logic.fact_state.state import ExecutionResourcesManager +from starkware.starknet.business_logic.state.state import UpdatesTrackerState from starkware.starknet.business_logic.state.state_api import SyncState from starkware.starknet.definitions import constants, fields from starkware.starknet.definitions.error_codes import StarknetErrorCode @@ -58,7 +59,7 @@ def get_return_values(runner: CairoFunctionRunner) -> List[int]: return cast(List[int], values) -def verify_version(version: int, only_query: bool): +def verify_version(version: int, only_query: bool, old_supported_versions: List[int]): """ Validates the given transaction version. @@ -67,7 +68,7 @@ def verify_version(version: int, only_query: bool): being invoked in the StarkNet OS. """ assert constants.TRANSACTION_VERSION == 1 - allowed_versions = [0, constants.TRANSACTION_VERSION] + allowed_versions = [*old_supported_versions, constants.TRANSACTION_VERSION] if only_query: error_code = StarknetErrorCode.INVALID_TRANSACTION_QUERYING_VERSION allowed_versions += [constants.QUERY_VERSION_BASE + v for v in allowed_versions] @@ -156,6 +157,7 @@ def calculate_tx_resources( resources_manager: ExecutionResourcesManager, call_infos: Iterable[Optional[CallInfo]], tx_type: TransactionType, + state: UpdatesTrackerState, l1_handler_payload_size: Optional[int] = None, ) -> ResourcesMapping: """ @@ -164,11 +166,9 @@ def calculate_tx_resources( Used for transaction fee; calculation is made as if the transaction is the first in batch, for consistency. """ - # Number of modified contracts by the most recently applied-on-state transaction. - n_modified_contracts_by_tx = len(resources_manager.modified_contracts.keys()) + (n_modified_contracts, n_storage_changes) = state.count_actual_storage_changes() non_optional_call_infos = [call for call in call_infos if call is not None] - tx_syscall_counter = resources_manager.syscall_counter n_deployments = 0 for call_info in non_optional_call_infos: n_deployments += get_call_n_deployments(call_info=call_info) @@ -179,14 +179,14 @@ def calculate_tx_resources( l1_gas_usage = calculate_tx_gas_usage( l2_to_l1_messages=l2_to_l1_messages, - n_modified_contracts=n_modified_contracts_by_tx, - n_storage_writes=tx_syscall_counter.get("storage_write", 0) - + FEE_TRANSFER_N_STORAGE_CHANGES_TO_CHARGE, + n_modified_contracts=n_modified_contracts, + n_storage_changes=n_storage_changes + FEE_TRANSFER_N_STORAGE_CHANGES_TO_CHARGE, l1_handler_payload_size=l1_handler_payload_size, n_deployments=n_deployments, ) cairo_usage = resources_manager.cairo_usage + tx_syscall_counter = resources_manager.syscall_counter # Add additional Cairo resources needed for the OS to run the transaction. cairo_usage += get_additional_os_resources(syscall_counter=tx_syscall_counter, tx_type=tx_type) @@ -272,11 +272,11 @@ def validate_entrypoint_execution_context(resources_manager: ExecutionResourcesM ) -def verify_no_calls_to_other_contracts(call_info: CallInfo): +def verify_no_calls_to_other_contracts(call_info: CallInfo, function_name: str): invoked_contract_address = call_info.contract_address for internal_call in call_info.gen_call_topology(): if internal_call.contract_address != invoked_contract_address: raise StarkException( code=StarknetErrorCode.UNAUTHORIZED_ACTION_ON_VALIDATE, - message="Calling other contracts during `validate` execution is forbidden.", + message=f"Calling other contracts during {function_name} execution is forbidden.", ) diff --git a/src/starkware/starknet/cli/CMakeLists.txt b/src/starkware/starknet/cli/CMakeLists.txt index 992cf233..1d90914a 100644 --- a/src/starkware/starknet/cli/CMakeLists.txt +++ b/src/starkware/starknet/cli/CMakeLists.txt @@ -7,7 +7,6 @@ python_lib(starknet_cli_lib LIBS cairo_compile_lib - cairo_tracer_lib cairo_version_lib cairo_vm_crypto_lib cairo_vm_utils_lib diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index 24d321f1..fd95d747 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -12,7 +12,6 @@ from web3 import Web3 -from services.everest.definitions import fields as everest_fields from services.external_api.client import RetryConfig from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager @@ -20,7 +19,6 @@ from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.type_system import mark_type_resolved from starkware.cairo.lang.compiler.type_utils import check_felts_only_type -from starkware.cairo.lang.tracer.tracer_data import field_element_repr from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager from starkware.python.utils import from_bytes @@ -54,6 +52,7 @@ AccountTransaction, Declare, Deploy, + DeployAccount, InvokeFunction, Transaction, ) @@ -121,10 +120,6 @@ def parse_block_identifiers( return block_hash, block_number -def felt_formatter(hex_felt: str) -> str: - return field_element_repr(val=int(hex_felt, 16), prime=everest_fields.FeltField.upper_bound) - - def get_optional_arg_value(args, arg_name: str, environment_var: str) -> Optional[str]: """ Returns the value of the given argument from args. If the argument was not specified, returns @@ -390,7 +385,7 @@ async def load_account( f"Unable to find wallet '{wallet}': Class '{class_name}' was not found." ) from None - return await account_class.create(starknet_context=starknet_context, account_name=account_name) + return account_class.create(starknet_context=starknet_context, account_name=account_name) def handle_network_param(args): @@ -528,27 +523,17 @@ async def create_invoke_tx_for_deploy( Creates and returns an InvokeFunction transaction to deploy a contract with the given arguments, which is wrapped and signed by the wallet provider. """ - version = constants.QUERY_VERSION if call else constants.TRANSACTION_VERSION account = await load_account_from_args(args=args) - wrapped_method, contract_address = await account.deploy_contract( + return await account.deploy_contract( class_hash=class_hash, salt=salt, constructor_calldata=constructor_calldata, deploy_from_zero=args.deploy_from_zero, chain_id=get_chain_id(args), max_fee=max_fee, - version=version, + version=constants.QUERY_VERSION if call else constants.TRANSACTION_VERSION, nonce_callback=create_get_nonce_callback(args=args), ) - tx = InvokeFunction( - contract_address=wrapped_method.address, - calldata=wrapped_method.calldata, - max_fee=wrapped_method.max_fee, - version=version, - signature=wrapped_method.signature, - nonce=wrapped_method.nonce, - ) - return tx, contract_address async def create_invoke_tx( @@ -582,7 +567,7 @@ async def create_invoke_tx( "Signature cannot be passed explicitly when using an account contract. " "Consider making a direct contract call using --no_wallet." ) - wrapped_method = await account.sign_invoke_transaction( + return await account.invoke( contract_address=invoke_tx_args.address, selector=invoke_tx_args.selector, calldata=invoke_tx_args.calldata, @@ -592,14 +577,6 @@ async def create_invoke_tx( nonce_callback=create_get_nonce_callback(args=args), dry_run=args.dry_run, ) - return InvokeFunction( - contract_address=wrapped_method.address, - calldata=wrapped_method.calldata, - max_fee=wrapped_method.max_fee, - version=version, - nonce=wrapped_method.nonce, - signature=wrapped_method.signature, - ) async def create_declare_tx( @@ -638,20 +615,31 @@ async def create_declare_tx( "Signature cannot be passed explicitly when using an account contract. " "Consider making a direct declaration using --no_wallet." ) - wrapped_method = await account.declare( + return await account.declare( contract_class=declare_tx_args.contract_class, chain_id=get_chain_id(args), max_fee=max_fee, version=version, nonce_callback=create_get_nonce_callback(args=args), ) - return Declare( - contract_class=declare_tx_args.contract_class, - sender_address=wrapped_method.address, - max_fee=wrapped_method.max_fee, + + +async def create_deploy_account_tx( + args: argparse.Namespace, + account: Account, + max_fee: int, + query: bool, +) -> Tuple[DeployAccount, int]: + """ + Creates and returns a Deploy Account transaction with the given parameters along with the new + account address. + """ + version = constants.QUERY_VERSION if query else constants.TRANSACTION_VERSION + return await account.deploy_account( + max_fee=max_fee, version=version, - signature=wrapped_method.signature, - nonce=wrapped_method.nonce, + chain_id=get_chain_id(args), + dry_run=query, ) @@ -730,10 +718,7 @@ async def simulate_or_estimate_fee(args: argparse.Namespace, tx: AccountTransact # Subparsers. -async def declare( - args: argparse.Namespace, - command_args: List[str], -): +async def declare(args: argparse.Namespace, command_args: List[str]): """ Creates a Declare transaction and sends it to the gateway. In case a wallet is provided, the transaction is wrapped and signed by the wallet provider. Otherwise, a sender address and a @@ -930,14 +915,72 @@ async def deploy_with_invoke(args: argparse.Namespace): ) -async def deploy_account(args, command_args): +async def new_account(args, command_args): + parser = argparse.ArgumentParser(description="Initializes an account contract.") + # Use parse_args to add the --help flag for the subcommand. + parser.parse_args(command_args, namespace=args) + account = await load_account_from_args(args) + account.new_account() + + +async def deploy_account(args: argparse.Namespace, command_args: List[str]): parser = argparse.ArgumentParser( - description="Initialize the account and deploy the account contract to StarkNet." + description=( + "Deploys an initialized account contract to StarkNet. " + "For more information, see new_account." + ) ) + add_deploy_account_tx_arguments(parser=parser) # Use parse_args to add the --help flag for the subcommand. parser.parse_args(command_args, namespace=args) + validate_max_fee(max_fee=args.max_fee) account = await load_account_from_args(args) - await account.deploy() + + deploy_account_tx_for_simulate, _ = await create_deploy_account_tx( + args=args, + account=account, + max_fee=args.max_fee if args.max_fee is not None else 0, + query=True, + ) + + if args.simulate or args.estimate_fee: + await simulate_or_estimate_fee(args=args, tx=deploy_account_tx_for_simulate) + return + + assert args.block_hash is None and args.block_number is None, ( + "--block_hash and --block_number should only be passed when either --simulate or " + "--estimate_fee flag are used." + ) + max_fee = await compute_max_fee( + args=args, tx=deploy_account_tx_for_simulate, is_account_contract_invocation=True + ) + + tx, contract_address = await create_deploy_account_tx( + args=args, + account=account, + max_fee=max_fee, + query=False, + ) + + gateway_client = get_gateway_client(args) + gateway_response = await gateway_client.add_transaction(tx=tx) + assert_tx_received(gateway_response=gateway_response) + # Verify the address received from the gateway. + assert (actual_address := int(gateway_response["address"], 16)) == contract_address, ( + f"The address returned from the Gateway: 0x{actual_address:064x} " + f"does not match the address stored in the account: 0x{contract_address:064x}. " + "Are you using the correct version of the CLI?" + ) + + # Don't end sentences with '.', to allow easy double-click copy-pasting of the values. + print( + f"""\ +Sent deploy account contract transaction. + +Contract address: 0x{contract_address:064x} +Transaction hash: {gateway_response['transaction_hash']} +""" + ) async def call(args: argparse.Namespace, command_args: List[str]): @@ -954,7 +997,7 @@ async def call(args: argparse.Namespace, command_args: List[str]): gateway_response = await feeder_client.call_contract( call_function=call_function_args, block_hash=args.block_hash, block_number=args.block_number ) - print(*map(felt_formatter, gateway_response["result"])) + print(*map(fields.felt_formatter, gateway_response["result"])) async def invoke(args: argparse.Namespace, command_args: List[str]): @@ -1353,9 +1396,36 @@ async def get_storage_at(args, command_args): # Add arguments. +def add_account_tx_arguments(parser: argparse.ArgumentParser): + """ + Adds the arguments: max_fee, signature and nonce. + """ + parser.add_argument( + "--nonce", + type=int, + help=( + "Used for explicitly specifying the transaction nonce. " + "If not specified, the current nonce of the account contract " + "(as returned from StarkNet) will be used." + ), + ) + parser.add_argument( + "--signature", + type=str, + nargs="*", + default=[], + help="The signature information for transaction.", + ) + parser.add_argument( + "--max_fee", + type=int, + help="The maximal fee to be paid for the execution of the transaction.", + ) + + def add_simulate_tx_arguments(parser: argparse.ArgumentParser): """ - Adds the arguments: simulate, estimate_fee. + Adds the arguments: simulate, estimate_fee and the block identifier arguments. """ parser.add_argument( "--simulate", @@ -1367,6 +1437,10 @@ def add_simulate_tx_arguments(parser: argparse.ArgumentParser): action="store_true", help="Estimates the fee of the transaction.", ) + add_block_identifier_arguments( + parser=parser, + block_role_description="be used as the context for the transaction simulation", + ) def add_declare_tx_arguments(parser: argparse.ArgumentParser): @@ -1385,35 +1459,11 @@ def add_declare_tx_arguments(parser: argparse.ArgumentParser): type=str, help="The address of the account contract sending the transaction.", ) - parser.add_argument( - "--max_fee", - type=int, - help="The maximal fee to be paid for the declaration.", - ) - parser.add_argument( - "--signature", - type=str, - nargs="*", - default=[], - help="The signature information for the declaration.", - ) - parser.add_argument( - "--nonce", - type=int, - help=( - "Used for explicitly specifying the transaction nonce. " - "If not specified, the current nonce of the account contract " - "(as returned from StarkNet) will be used." - ), - ) + add_account_tx_arguments(parser=parser) parser.add_argument( "--token", type=str, help="Used for declaring contracts in Alpha MainNet.", required=False ) add_simulate_tx_arguments(parser=parser) - add_block_identifier_arguments( - parser=parser, - block_role_description="be used as the context for the transaction simulation", - ) def add_call_function_arguments(parser: argparse.ArgumentParser): @@ -1447,39 +1497,27 @@ def add_call_l1_handler_arguments(parser: argparse.ArgumentParser): def add_invoke_tx_arguments(parser: argparse.ArgumentParser): """ - Adds the arguments: address, abi, function, inputs, nonce, signature, the simulate arguments and - the block identifier arguments. + Adds the arguments: address, abi, function, inputs, nonce, signature, max_fee, dry_run, the + simulate arguments and the block identifier arguments. """ add_call_function_arguments(parser=parser) - parser.add_argument( - "--nonce", - type=int, - help=( - "Used for explicitly specifying the transaction nonce. " - "If not specified, the current nonce of the account contract " - "(as returned from StarkNet) will be used." - ), - ) + add_account_tx_arguments(parser=parser) parser.add_argument( "--dry_run", action="store_true", help="Prepare the transaction and print it without signing or sending it.", ) + add_simulate_tx_arguments(parser=parser) + + +def add_deploy_account_tx_arguments(parser: argparse.ArgumentParser): + """ + Adds the arguments: max_fee, the simulate arguments and the block identifier arguments. + """ parser.add_argument( - "--signature", - type=str, - nargs="*", - default=[], - help="The signature information for the invoked function.", - ) - parser.add_argument( - "--max_fee", type=int, help="The maximal fee to be paid for the function invocation." + "--max_fee", type=int, help="The maximal fee to be paid for the deployment." ) add_simulate_tx_arguments(parser=parser) - add_block_identifier_arguments( - parser=parser, - block_role_description="be used as the context for the transaction simulation", - ) def add_block_identifier_arguments( @@ -1522,6 +1560,7 @@ async def main(): "get_transaction_receipt": get_transaction_receipt, "get_transaction_trace": get_transaction_trace, "invoke": invoke, + "new_account": new_account, "tx_status": tx_status, } parser = argparse.ArgumentParser(description="A tool to communicate with StarkNet.") @@ -1603,4 +1642,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + sys.exit(asyncio.run(main())) diff --git a/src/starkware/starknet/common/constants.cairo b/src/starkware/starknet/common/constants.cairo index a77dd532..374c6c45 100644 --- a/src/starkware/starknet/common/constants.cairo +++ b/src/starkware/starknet/common/constants.cairo @@ -4,5 +4,6 @@ const ORIGIN_ADDRESS = 0; // Transaction hash prefixes. const DECLARE_HASH_PREFIX = 'declare'; const DEPLOY_HASH_PREFIX = 'deploy'; +const DEPLOY_ACCOUNT_HASH_PREFIX = 'deploy_account'; const INVOKE_HASH_PREFIX = 'invoke'; const L1_HANDLER_HASH_PREFIX = 'l1_handler'; diff --git a/src/starkware/starknet/common/eth_utils_test.py b/src/starkware/starknet/common/eth_utils_test.py index d5fe2ddc..8dd74ebd 100644 --- a/src/starkware/starknet/common/eth_utils_test.py +++ b/src/starkware/starknet/common/eth_utils_test.py @@ -38,6 +38,5 @@ def test_assert_eth_address_range(runner: CairoFunctionRunner, address, error_me "assert_eth_address_range", range_check_ptr=runner.range_check_builtin.base, address=address, + verify_implicit_args_segment=True, ) - (range_check_ptr_end,) = runner.get_return_values(1) - assert range_check_ptr_end.segment_index == runner.range_check_builtin.base.segment_index diff --git a/src/starkware/starknet/common/storage_test.py b/src/starkware/starknet/common/storage_test.py index 04553682..4b0cddf5 100644 --- a/src/starkware/starknet/common/storage_test.py +++ b/src/starkware/starknet/common/storage_test.py @@ -43,8 +43,10 @@ def test_constants(program: Program): ], ) def test_normalize_address(runner: CairoFunctionRunner, value): - runner.run("normalize_address", range_check_ptr=runner.range_check_builtin.base, addr=value) - range_check_ptr_end, result = runner.get_return_values(2) - assert range_check_ptr_end.segment_index == runner.range_check_builtin.base.segment_index - + (_, (result,)) = runner.run( + "normalize_address", + range_check_ptr=runner.range_check_builtin.base, + addr=value, + verify_implicit_args_segment=True, + ) assert result == value % ADDR_BOUND diff --git a/src/starkware/starknet/compiler/validation_utils.py b/src/starkware/starknet/compiler/validation_utils.py index 46647de6..6b48443a 100644 --- a/src/starkware/starknet/compiler/validation_utils.py +++ b/src/starkware/starknet/compiler/validation_utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Type, TypeVar +from typing import Dict, Iterable, List, Optional, Tuple, Type, TypeVar from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.cairo_types import TypeTuple @@ -15,8 +15,11 @@ from starkware.starknet.definitions import constants from starkware.starknet.public.abi import ( ACCOUNT_ENTRY_POINT_NAMES, + CONSTRUCTOR_ENTRY_POINT_NAME, EXECUTE_ENTRY_POINT_NAME, + MANDATORY_ACCOUNT_ENTRY_POINT_NAMES, VALIDATE_DECLARE_ENTRY_POINT_NAME, + VALIDATE_DEPLOY_ENTRY_POINT_NAME, VALIDATE_ENTRY_POINT_NAME, AbiEntryType, AbiType, @@ -24,6 +27,13 @@ TAttr = TypeVar("TAttr") +VALIDATE_DECLARE_ARGS = [{"name": "class_hash", "type": "felt"}] +VALIDATE_DEPLOY_REQUIRED_ARGS = [ + {"name": "class_hash", "type": "felt"}, + {"name": "contract_address_salt", "type": "felt"}, +] + + # Common verifications. @@ -91,21 +101,56 @@ def verify_account_contract(contract_abi: AbiType, is_account_contract: bool): For non-account contracts, verifies that it contains none of them. """ account_entry_points: Dict[str, AbiEntryType] = {} + constructor_inputs: List[AbiEntryType] = [] # Default constructor. # Collect account contract special entry points. for entry_point in contract_abi: if entry_point["type"] == "function" and entry_point["name"] in ACCOUNT_ENTRY_POINT_NAMES: account_entry_points[entry_point["name"]] = entry_point + if ( + entry_point["type"] == "constructor" + and entry_point["name"] == CONSTRUCTOR_ENTRY_POINT_NAME + ): + # Contract has an explicit constructor. + constructor_inputs = entry_point["inputs"] + + account_entry_point_names = ( + set(account_entry_points.keys()) & MANDATORY_ACCOUNT_ENTRY_POINT_NAMES + ) + missing_account_entry_point_names = ( + MANDATORY_ACCOUNT_ENTRY_POINT_NAMES - account_entry_point_names + ) + optional_validate_deploy_entry_point = account_entry_points.get( + VALIDATE_DEPLOY_ENTRY_POINT_NAME + ) + + # Verifications. + + if optional_validate_deploy_entry_point is not None: + # Handle validate_deploy. + expected_validate_deploy_args = VALIDATE_DEPLOY_REQUIRED_ARGS + constructor_inputs + if optional_validate_deploy_entry_point["inputs"] != expected_validate_deploy_args: + message = f"""\ +Warning: +the arguments of '{VALIDATE_DEPLOY_ENTRY_POINT_NAME}' are expected to start with: +'{format_inputs(inputs=VALIDATE_DEPLOY_REQUIRED_ARGS)}' +followed by the constructor's arguments (if exist). Found: +'{format_inputs(inputs=optional_validate_deploy_entry_point['inputs'])}'. + +Deploying this contract using DeployAccount transaction is not recommended and would probably fail. +""" + print(message) - account_entry_point_names = set(account_entry_points.keys()) + # Contract-type specific Verifications. if is_account_contract: # Handle account contract. - if account_entry_point_names != ACCOUNT_ENTRY_POINT_NAMES: + if len(missing_account_entry_point_names) > 0: raise PreprocessorError( message=( - "Account contracts must have external functions named " - f"{ACCOUNT_ENTRY_POINT_NAMES}, found: {list(account_entry_point_names)}." + "Account contracts must have external functions named: " + f"{sort_and_format(names=MANDATORY_ACCOUNT_ENTRY_POINT_NAMES)}. " + f"Missing: {sort_and_format(names=missing_account_entry_point_names)}." ) ) @@ -120,11 +165,11 @@ def verify_account_contract(contract_abi: AbiType, is_account_contract: bool): ) ) - if validate_declare_entry_point["inputs"] != [{"name": "class_hash", "type": "felt"}]: + if validate_declare_entry_point["inputs"] != VALIDATE_DECLARE_ARGS: raise PreprocessorError( message=( f"'{VALIDATE_DECLARE_ENTRY_POINT_NAME}' function must have one argument " - "`class_hash: felt`." + f"'{format_inputs(inputs=VALIDATE_DECLARE_ARGS)}'." ) ) else: @@ -133,7 +178,8 @@ def verify_account_contract(contract_abi: AbiType, is_account_contract: bool): # One of the entry points exists in a non-account contract. raise PreprocessorError( message=( - f"Only account contracts may have functions named {account_entry_point_names}. " + f"Only account contracts may have functions " + f"named {sort_and_format(names=account_entry_point_names)}. " "Use the --account_contract flag to compile an account contract." ) ) @@ -194,3 +240,21 @@ def encode_calldata_arguments( has_range_check_builtin=True, identifiers=visitor.identifiers, ) + + +def format_inputs(inputs: List[Dict[str, str]]) -> str: + """ + Returns a readable string given the arguments of an entry point. + For example, + [{'name': 'arg', 'type': 'felt'}, {'name': 'argument', 'type': 'felt'}] -> + 'arg: felt, argument: felt'. + """ + return ", ".join(f"{arg['name']}: {arg['type']}" for arg in inputs) + + +def sort_and_format(names: Iterable[str]) -> str: + """ + Converts an iterable of names to a string listing the names in alphabetical order. + For example: {"Bob", "Alice", "Carol"} -> "'Alice', 'Bob', 'Carol'". + """ + return ", ".join(f"'{name}'" for name in sorted(names)) diff --git a/src/starkware/starknet/compiler/validation_utils_test.py b/src/starkware/starknet/compiler/validation_utils_test.py index d8aabed7..11369d64 100644 --- a/src/starkware/starknet/compiler/validation_utils_test.py +++ b/src/starkware/starknet/compiler/validation_utils_test.py @@ -1,102 +1,173 @@ -from typing import Dict, Iterable, List, Optional +from typing import List, Optional import pytest from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError -from starkware.starknet.compiler.validation_utils import verify_account_contract +from starkware.starknet.compiler.validation_utils import ( + VALIDATE_DECLARE_ARGS, + VALIDATE_DEPLOY_REQUIRED_ARGS, + verify_account_contract, +) from starkware.starknet.public import abi as starknet_abi -def create_mock_contract_abi( - entry_point_names: Iterable[str], - deform_entry_point_name: Optional[str] = None, - inputs: List[Dict[str, str]] = [], -) -> starknet_abi.AbiType: - mock_abi = [ +def create_account_contract_abi() -> starknet_abi.AbiType: + abi: starknet_abi.AbiType = [ { "type": "function", "name": entry_point_name, - "inputs": [{"name": "class_hash", "type": "felt"}], + "inputs": ( + VALIDATE_DECLARE_ARGS + if entry_point_name != starknet_abi.VALIDATE_DEPLOY_ENTRY_POINT_NAME + else VALIDATE_DEPLOY_REQUIRED_ARGS + ), } - for entry_point_name in entry_point_names - if entry_point_name != deform_entry_point_name + for entry_point_name in starknet_abi.ACCOUNT_ENTRY_POINT_NAMES ] - if deform_entry_point_name is not None: - mock_abi.append( - { - "type": "function", - "name": deform_entry_point_name, - "inputs": inputs, - } - ) - return mock_abi + + return abi + + +def create_faulty_account_contract_abi( + entry_point_to_remove: Optional[str] = None, + entry_point_to_corrupt: Optional[str] = None, + deformed_args: Optional[List[starknet_abi.AbiEntryType]] = None, +) -> starknet_abi.AbiType: + abi = create_account_contract_abi() + if entry_point_to_remove is not None: + abi = [entry_point for entry_point in abi if entry_point["name"] != entry_point_to_remove] + + for entry_point in abi: + if entry_point_to_corrupt is not None and entry_point["name"] == entry_point_to_corrupt: + assert deformed_args is not None + entry_point["inputs"] = deformed_args + + return abi def test_positive_flow_verify_account_contract(): # Account contract. - mock_account_contract_abi = create_mock_contract_abi( - entry_point_names=starknet_abi.ACCOUNT_ENTRY_POINT_NAMES + account_contract_abi = create_account_contract_abi() + verify_account_contract(contract_abi=account_contract_abi, is_account_contract=True) + + # Account contract without '__validate_declare__'. + account_contract_abi = create_faulty_account_contract_abi( + entry_point_to_remove=starknet_abi.VALIDATE_DEPLOY_ENTRY_POINT_NAME ) - verify_account_contract(contract_abi=mock_account_contract_abi, is_account_contract=True) + verify_account_contract(contract_abi=account_contract_abi, is_account_contract=True) # Non-account contract. - mock_account_contract_abi = create_mock_contract_abi(entry_point_names=["mock_entry_point"]) - verify_account_contract(contract_abi=mock_account_contract_abi, is_account_contract=False) + abi: starknet_abi.AbiType = [ + { + "type": "function", + "name": "increase balance", + "inputs": [{"name": "amount", "type": "felt"}], + } + ] + verify_account_contract(contract_abi=abi, is_account_contract=False) -def test_negative_flow_verify_account_contract(): +def test_negative_flow_verify_account_contract(capsys: pytest.CaptureFixture): """ Test malformed account contracts ABI. """ # Contract missing one or more of the account entry points: - # "__execute__", "__validate__", "__validate_declare__". - mock_defected_account_contract_abi = create_mock_contract_abi( - entry_point_names={ - starknet_abi.VALIDATE_ENTRY_POINT_NAME, - starknet_abi.EXECUTE_ENTRY_POINT_NAME, - } + # "__execute__", "__validate__", "__validate_declare__", "__validate_deploy__". + defected_account_contract_abi = create_faulty_account_contract_abi( + entry_point_to_remove=starknet_abi.VALIDATE_ENTRY_POINT_NAME ) with pytest.raises( - PreprocessorError, match="Account contracts must have external functions named" + PreprocessorError, + match=( + "Account contracts must have external functions " + "named: '__execute__', '__validate__', '__validate_declare__'. " + "Missing: '__validate__'." + ), ): verify_account_contract( - contract_abi=mock_defected_account_contract_abi, is_account_contract=True + contract_abi=defected_account_contract_abi, is_account_contract=True ) with pytest.raises(PreprocessorError, match="Only account contracts may have functions named"): verify_account_contract( - contract_abi=mock_defected_account_contract_abi, is_account_contract=False + contract_abi=defected_account_contract_abi, is_account_contract=False ) - # Contract where "__declare__" and "__execute__" have different calldata. - mock_defected_account_contract_abi = create_mock_contract_abi( - entry_point_names=starknet_abi.ACCOUNT_ENTRY_POINT_NAMES, - deform_entry_point_name=starknet_abi.EXECUTE_ENTRY_POINT_NAME, - inputs=[ - {"name": "class_hash", "type": "felt"}, - {"name": "contract_address", "type": "felt"}, - ], + # Contract where "__validate__" and "__execute__" have different calldata. + defected_account_contract_abi = create_faulty_account_contract_abi( + entry_point_to_corrupt=starknet_abi.EXECUTE_ENTRY_POINT_NAME, + deformed_args=[{"name": "unique_arg", "type": "felt"}], ) with pytest.raises( - PreprocessorError, match="Account contracts must have the exact same calldata for" + PreprocessorError, + match=( + "Account contracts must have the exact same calldata for '__validate__' " + "and '__execute__' functions." + ), ): verify_account_contract( - contract_abi=mock_defected_account_contract_abi, is_account_contract=True + contract_abi=defected_account_contract_abi, is_account_contract=True ) # Contract where "__validate_declare__" have malformed calldata. - mock_defected_account_contract_abi = create_mock_contract_abi( - entry_point_names=starknet_abi.ACCOUNT_ENTRY_POINT_NAMES, - deform_entry_point_name=starknet_abi.VALIDATE_DECLARE_ENTRY_POINT_NAME, - inputs=[ - {"name": "class_hash", "type": "felt"}, - {"name": "contract_address", "type": "felt"}, - ], + defected_account_contract_abi = create_faulty_account_contract_abi( + entry_point_to_corrupt=starknet_abi.VALIDATE_DECLARE_ENTRY_POINT_NAME, + deformed_args=VALIDATE_DECLARE_ARGS + [{"name": "contract_address", "type": "felt"}], ) with pytest.raises( PreprocessorError, - match=f"'{starknet_abi.VALIDATE_DECLARE_ENTRY_POINT_NAME}' function must have one argument " - "`class_hash: felt`.", + match="'__validate_declare__' function must have one argument 'class_hash: felt'.", ): verify_account_contract( - contract_abi=mock_defected_account_contract_abi, is_account_contract=True + contract_abi=defected_account_contract_abi, is_account_contract=True ) + + # Test "__validate_deploy__". + + warning_template = """\ +Warning: +the arguments of '__validate_deploy__' are expected to start with: +'class_hash: felt, contract_address_salt: felt' +followed by the constructor's arguments (if exist). Found: +'{actual_inputs}'. + +Deploying this contract using DeployAccount transaction is not recommended and would probably fail. + +""" + + # Contract where "__validate_deploy__" arguments are in the wrong order. + defected_account_contract_abi = create_faulty_account_contract_abi( + entry_point_to_corrupt=starknet_abi.VALIDATE_DEPLOY_ENTRY_POINT_NAME, + deformed_args=VALIDATE_DEPLOY_REQUIRED_ARGS[::-1], + ) + verify_account_contract(contract_abi=defected_account_contract_abi, is_account_contract=True) + captured = capsys.readouterr() + assert captured.out == warning_template.format( + actual_inputs="contract_address_salt: felt, class_hash: felt" + ) + + # Contract without a "__constructor__" and "__validate_deploy__" has additional calldata. + defected_account_contract_abi = create_faulty_account_contract_abi( + entry_point_to_corrupt=starknet_abi.VALIDATE_DEPLOY_ENTRY_POINT_NAME, + deformed_args=( + VALIDATE_DEPLOY_REQUIRED_ARGS + [{"name": "contract_address", "type": "felt"}] + ), + ) + verify_account_contract(contract_abi=defected_account_contract_abi, is_account_contract=True) + captured = capsys.readouterr() + assert captured.out == warning_template.format( + actual_inputs="class_hash: felt, contract_address_salt: felt, contract_address: felt" + ) + + # Contract where "__constructor__" is not contained in "__validate_deploy__" calldata. + defected_account_contract_abi = create_account_contract_abi() + [ + { + "type": "constructor", + "name": starknet_abi.CONSTRUCTOR_ENTRY_POINT_NAME, + "inputs": [{"name": "amount", "type": "felt"}], + } + ] + verify_account_contract(contract_abi=defected_account_contract_abi, is_account_contract=True) + captured = capsys.readouterr() + assert captured.out == warning_template.format( + actual_inputs="class_hash: felt, contract_address_salt: felt" + ) diff --git a/src/starkware/starknet/core/os/contracts.cairo b/src/starkware/starknet/core/os/contracts.cairo index c1e43885..4fa0170f 100644 --- a/src/starkware/starknet/core/os/contracts.cairo +++ b/src/starkware/starknet/core/os/contracts.cairo @@ -211,7 +211,7 @@ func load_contract_class_facts_inner{pedersen_ptr: HashBuiltin*, range_check_ptr computed_hash = ids.contract_class_fact.hash expected_hash = from_bytes(class_hash) assert computed_hash == expected_hash, ( - "Computed class_hash is inconsistent with the hash in the os_input" + "Computed class_hash is inconsistent with the hash in the os_input. " f"Computed hash = {computed_hash}, Expected hash = {expected_hash}.") vm_load_program(contract_class.program, ids.contract_class.bytecode_ptr) diff --git a/src/starkware/starknet/core/os/program_hash.json b/src/starkware/starknet/core/os/program_hash.json index c9f471db..a14aca18 100644 --- a/src/starkware/starknet/core/os/program_hash.json +++ b/src/starkware/starknet/core/os/program_hash.json @@ -1,3 +1,3 @@ { - "program_hash": "0x3695a2473fbe96955d7f2070389e0ea2d57fc0271fb1fc8391c77ef9a9cf55d" + "program_hash": "0x5e2465adccccf3e4fcbb19e4a2415690c49e94c9ce211a193ad81687ba10f59" } diff --git a/src/starkware/starknet/core/os/syscall_utils.py b/src/starkware/starknet/core/os/syscall_utils.py index 4b6375a0..fb878f5a 100644 --- a/src/starkware/starknet/core/os/syscall_utils.py +++ b/src/starkware/starknet/core/os/syscall_utils.py @@ -915,13 +915,8 @@ def _storage_write(self, address: int, value: int): # This value is needed to create the DictAccess while executing the corresponding # storage_write system call. self.starknet_storage.read(address=address) - self.starknet_storage.write(address=address, value=value) - # Update modified contracts (for the bouncer). - # Note that this is a simplified update - we are considering every write - # as a new change in storage (w.r.t. the state of the previous batch), but it could be that - # a write actually cancels a change; e.g., 0 -> 5, 5 -> 0. - self.resources_manager.modified_contracts[self.contract_address] = None + self.starknet_storage.write(address=address, value=value) def get_sequencer_address(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): return super().get_sequencer_address(segments=segments, syscall_ptr=syscall_ptr) @@ -962,16 +957,6 @@ class OsSysCallHandler(SysCallHandlerBase): The SysCallHandler implementation that is used by the gps ambassador. """ - class CallState: - def __init__(self, call_info: CallInfo): - # The CallInfo of the call. - self.call_info = call_info - # An iterator to the read_values array, used to fill the DictAccess array during the - # system call execution. - # This iterator might not be exhausted when another a nested call is entered. - # It will continue to be consumed once the inner call returns. - self.execute_syscall_read_iterator = iter(call_info.storage_read_values) - def __init__( self, tx_execution_infos: List[TransactionExecutionInfo], @@ -988,7 +973,7 @@ def __init__( # A stack that keeps track of the state of the calls being executed now. # The last item is the state of the current call; the one before it, is the # state of the caller (the call the called the current call); and so on. - self.call_stack: List[OsSysCallHandler.CallState] = [] + self.call_stack: List[CallInfo] = [] # An iterator over contract addresses that were deployed during that call. self.deployed_contracts_iterator: Iterator[int] = iter([]) @@ -1048,12 +1033,12 @@ def _deploy(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue) def _get_caller_address( self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> int: - return self.call_stack[-1].call_info.caller_address + return self.call_stack[-1].caller_address def _get_contract_address( self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> int: - return self.call_stack[-1].call_info.contract_address + return self.call_stack[-1].contract_address def _get_tx_info_ptr(self, segments: MemorySegmentManager) -> RelocatableValue: assert self.tx_info_ptr is not None @@ -1067,18 +1052,15 @@ def _storage_write(self, address: int, value: int): # in each write operation. See BusinessLogicSysCallHandler._storage_write(). next(self.execute_code_read_iterator) - def execute_syscall_storage_read(self) -> int: - """ - Advances execute_syscall_read_iterator and returns the value that was read. - """ - return next(self.call_stack[-1].execute_syscall_read_iterator) - - def execute_syscall_storage_write(self) -> int: + def execute_syscall_storage_write(self, contract_address: int, key: int, value: int) -> int: """ - Advances execute_syscall_read_iterator and returns the storage value before + Updates the cached storage and returns the storage value before the write operation. """ - return self.execute_syscall_storage_read() + previous_value = self.starknet_storage_by_address[contract_address].write( + key=key, value=value + ) + return previous_value def start_tx(self, tx_info_ptr: RelocatableValue): """ @@ -1112,7 +1094,7 @@ def enter_call(self): self.assert_interators_exhausted() call_info = next(self.call_iterator) - self.call_stack.append(self.CallState(call_info=call_info)) + self.call_stack.append(call_info) self.deployed_contracts_iterator = ( call.contract_address @@ -1124,9 +1106,7 @@ def enter_call(self): def exit_call(self): self.assert_interators_exhausted() - - call_state = self.call_stack.pop() - assert_exhausted(iterator=call_state.execute_syscall_read_iterator) + self.call_stack.pop() def skip_tx(self): """ diff --git a/src/starkware/starknet/core/os/transaction_hash/transaction_hash.py b/src/starkware/starknet/core/os/transaction_hash/transaction_hash.py index c6c1e6cc..ace7c029 100644 --- a/src/starkware/starknet/core/os/transaction_hash/transaction_hash.py +++ b/src/starkware/starknet/core/os/transaction_hash/transaction_hash.py @@ -13,6 +13,7 @@ class TransactionHashPrefix(Enum): DECLARE = from_bytes(b"declare") DEPLOY = from_bytes(b"deploy") + DEPLOY_ACCOUNT = from_bytes(b"deploy_account") INVOKE = from_bytes(b"invoke") L1_HANDLER = from_bytes(b"l1_handler") @@ -82,6 +83,30 @@ def calculate_deploy_transaction_hash( ) +def calculate_deploy_account_transaction_hash( + version: int, + contract_address: int, + class_hash: int, + constructor_calldata: Sequence[int], + max_fee: int, + nonce: int, + salt: int, + chain_id: int, + hash_function: Callable[[int, int], int] = pedersen_hash, +) -> int: + return calculate_transaction_hash_common( + tx_hash_prefix=TransactionHashPrefix.DEPLOY_ACCOUNT, + version=version, + contract_address=contract_address, + entry_point_selector=0, + calldata=[class_hash, salt, *constructor_calldata], + max_fee=max_fee, + chain_id=chain_id, + additional_data=[nonce], + hash_function=hash_function, + ) + + def calculate_declare_transaction_hash( contract_class: ContractClass, chain_id: int, diff --git a/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test.py b/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test.py index 75d014f7..1bd7b42c 100644 --- a/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test.py +++ b/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test.py @@ -9,6 +9,7 @@ from starkware.starknet.core.os.transaction_hash.transaction_hash import ( TransactionHashPrefix, calculate_declare_transaction_hash, + calculate_deploy_account_transaction_hash, calculate_deploy_transaction_hash, calculate_transaction_hash_common, compute_hash_on_elements, @@ -56,7 +57,7 @@ def run_cairo_transaction_hash( @pytest.mark.parametrize("tx_hash_prefix", set(TransactionHashPrefix)) -@pytest.mark.parametrize("calldata", [[], [659], [540, 338], [73, 443, 234, 350, 841]]) +@pytest.mark.parametrize("calldata", [[], [540, 338]]) @pytest.mark.parametrize("max_fee", [0, 10, 299]) @pytest.mark.parametrize("version", [0]) @pytest.mark.parametrize("additional_data", [[], [17]]) @@ -98,7 +99,7 @@ def test_transaction_hash_common_flow( ) -@pytest.mark.parametrize("constructor_calldata", [[], [658], [539, 337], [72, 442, 233, 349, 840]]) +@pytest.mark.parametrize("constructor_calldata", [[], [539, 337]]) def test_deploy_transaction_hash(constructor_calldata: List[int]): # Constant value unrelated to the transaction data. version = 111 @@ -131,6 +132,50 @@ def test_deploy_transaction_hash(constructor_calldata: List[int]): ) +@pytest.mark.parametrize("constructor_calldata", [[], [539, 337]]) +def test_deploy_account_transaction_hash(constructor_calldata: List[int]): + # Constant value unrelated to the transaction data. + entry_point_selector = 0 + + # Tested transaction data. + version = constants.TRANSACTION_VERSION + salt = 0 + contract_address = 19911991 + max_fee = 1 + chain_id = 2 + nonce = 0 + contract_class = get_contract_class(contract_name="dummy_account") + class_hash = compute_class_hash(contract_class=contract_class, hash_func=pedersen_hash) + calldata = [class_hash, salt, *constructor_calldata] + + expected_hash = compute_hash_on_elements( + data=[ + TransactionHashPrefix.DEPLOY_ACCOUNT.value, + version, + contract_address, + entry_point_selector, + compute_hash_on_elements(data=calldata, hash_func=pedersen_hash), + max_fee, + chain_id, + nonce, + ], + hash_func=pedersen_hash, + ) + assert ( + calculate_deploy_account_transaction_hash( + version=version, + contract_address=contract_address, + class_hash=class_hash, + constructor_calldata=constructor_calldata, + max_fee=max_fee, + nonce=nonce, + salt=salt, + chain_id=chain_id, + ) + == expected_hash + ) + + def test_declare_transaction_hash(): # Constant value unrelated to the transaction data. entry_point_selector = 0 @@ -138,8 +183,8 @@ def test_declare_transaction_hash(): # Tested transaction data. version = constants.TRANSACTION_VERSION sender_address = 19911991 - max_fee = 0 - chain_id = 1 + max_fee = 1 + chain_id = 2 nonce = 0 contract_class = get_contract_class(contract_name="dummy_account") class_hash = compute_class_hash(contract_class=contract_class, hash_func=pedersen_hash) diff --git a/src/starkware/starknet/core/os/transactions.cairo b/src/starkware/starknet/core/os/transactions.cairo index b182d693..80004185 100644 --- a/src/starkware/starknet/core/os/transactions.cairo +++ b/src/starkware/starknet/core/os/transactions.cairo @@ -1,5 +1,6 @@ from starkware.cairo.builtin_selection.select_builtins import select_builtins from starkware.cairo.builtin_selection.validate_builtins import validate_builtin, validate_builtins +from starkware.cairo.common.alloc import alloc from starkware.cairo.common.cairo_builtins import HashBuiltin from starkware.cairo.common.dict import dict_new, dict_read, dict_update, dict_write from starkware.cairo.common.dict_access import DictAccess @@ -11,6 +12,7 @@ from starkware.cairo.common.segments import relocate_segment from starkware.cairo.common.uint256 import Uint256 from starkware.starknet.common.constants import ( DECLARE_HASH_PREFIX, + DEPLOY_ACCOUNT_HASH_PREFIX, DEPLOY_HASH_PREFIX, INVOKE_HASH_PREFIX, L1_HANDLER_HASH_PREFIX, @@ -105,6 +107,10 @@ const VALIDATE_ENTRY_POINT_SELECTOR = ( const VALIDATE_DECLARE_ENTRY_POINT_SELECTOR = ( 0x289da278a8dc833409cabfdad1581e8e7d40e42dcaed693fa4008dcdb4963b3); +// get_selector_from_name('__validate_deploy__'). +const VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR = ( + 0x36fcbf06cd96843058359e1a75928beacfac10727dab22a3972f0af8aa92895); + // get_selector_from_name('transfer'). const TRANSFER_ENTRY_POINT_SELECTOR = ( 0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e); @@ -247,7 +253,7 @@ func execute_transactions_inner{ } if (tx_type == 'L1_HANDLER') { - // Handle L1 handler transaction. + // Handle the L1-handler transaction. execute_l1_handler_transaction(block_context=block_context); return execute_transactions_inner(block_context=block_context, n_txs=n_txs - 1); } @@ -258,6 +264,12 @@ func execute_transactions_inner{ return execute_transactions_inner(block_context=block_context, n_txs=n_txs - 1); } + if (tx_type == 'DEPLOY_ACCOUNT') { + // Handle the deploy-account transaction. + execute_deploy_account_transaction(block_context=block_context); + return execute_transactions_inner(block_context=block_context, n_txs=n_txs - 1); + } + assert tx_type = 'DECLARE'; // Handle the declare transaction. execute_declare_transaction(block_context=block_context); @@ -290,11 +302,11 @@ func charge_fee{ return (); } - // Transactions with fee should go through the EXECUTE_ENTRY_POINT_SELECTOR - // or VALIDATE_DECLARE_ENTRY_POINT_SELECTOR. + // Transactions with fee should go through an account contract. tempvar selector = tx_execution_context.selector; assert (selector - EXECUTE_ENTRY_POINT_SELECTOR) * - (selector - VALIDATE_DECLARE_ENTRY_POINT_SELECTOR) = 0; + (selector - VALIDATE_DECLARE_ENTRY_POINT_SELECTOR) * + (selector - VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR) = 0; local calldata: TransferCallData = TransferCallData( recipient=block_context.sequencer_address, @@ -724,8 +736,6 @@ func execute_storage_read{global_state_changes: DictAccess*}( local state_entry: StateEntry*; local new_state_entry: StateEntry*; %{ - syscall_handler.execute_syscall_storage_read() - # Fetch a state_entry in this hint and validate it in the update that comes next. ids.state_entry = __dict_manager.get_dict(ids.global_state_changes)[ids.contract_address] @@ -763,7 +773,11 @@ func execute_storage_write{global_state_changes: DictAccess*}( local state_entry: StateEntry*; local new_state_entry: StateEntry*; %{ - ids.prev_value = syscall_handler.execute_syscall_storage_write() + ids.prev_value = syscall_handler.execute_syscall_storage_write( + contract_address=ids.contract_address, + key=ids.syscall_ptr.address, + value=ids.syscall_ptr.value + ) # Fetch a state_entry in this hint and validate it in the update that comes next. ids.state_entry = __dict_manager.get_dict(ids.global_state_changes)[ids.contract_address] @@ -1200,11 +1214,7 @@ func execute_entry_point{ tempvar context = os_context; tempvar calldata_size = execution_context.calldata_size; tempvar calldata = execution_context.calldata; - %{ - vm_enter_scope({ - 'syscall_handler': syscall_handler, - }) - %} + %{ vm_enter_scope({'syscall_handler': syscall_handler}) %} call abs contract_entry_point; %{ vm_exit_scope() %} // Retrieve returned_builtin_ptrs_subset. @@ -1323,12 +1333,11 @@ func deploy_contract{ return (); } -func execute_deploy_transaction{ - range_check_ptr, - builtin_ptrs: BuiltinPointers*, - global_state_changes: DictAccess*, - outputs: OsCarriedOutputs*, -}(block_context: BlockContext*) { +// Prepares a constructor execution context based on the 'tx' hint variable. +// Leaves 'original_tx_info' empty - should be filled later on. +func prepare_constructor_execution_context{range_check_ptr, builtin_ptrs: BuiltinPointers*}() -> ( + constructor_execution_context: ExecutionContext*, salt: felt +) { alloc_locals; local contract_address_salt; @@ -1340,10 +1349,11 @@ func execute_deploy_transaction{ from starkware.python.utils import from_bytes ids.contract_address_salt = tx.contract_address_salt - ids.class_hash = from_bytes(tx.contract_hash) + ids.class_hash = from_bytes(tx.class_hash) ids.constructor_calldata_size = len(tx.constructor_calldata) ids.constructor_calldata = segments.gen_arg(arg=tx.constructor_calldata) %} + assert_nn(constructor_calldata_size); let hash_ptr = builtin_ptrs.pedersen; with hash_ptr { @@ -1363,7 +1373,7 @@ func execute_deploy_transaction{ ec_op=builtin_ptrs.ec_op, ); - local constructor_execution_context: ExecutionContext* = new ExecutionContext( + tempvar constructor_execution_context = new ExecutionContext( entry_point_type=ENTRY_POINT_TYPE_CONSTRUCTOR, caller_address=ORIGIN_ADDRESS, contract_address=contract_address, @@ -1374,6 +1384,110 @@ func execute_deploy_transaction{ original_tx_info=cast(nondet %{ segments.add() %}, TxInfo*), ); + return ( + constructor_execution_context=constructor_execution_context, salt=contract_address_salt + ); +} + +func execute_deploy_account_transaction{ + range_check_ptr, + builtin_ptrs: BuiltinPointers*, + global_state_changes: DictAccess*, + outputs: OsCarriedOutputs*, +}(block_context: BlockContext*) { + alloc_locals; + + // Calculate address and prepare constructor execution context. + let ( + local constructor_execution_context: ExecutionContext*, local salt + ) = prepare_constructor_execution_context(); + + // Prepare validate_deploy calldata. + let (validate_deploy_calldata: felt*) = alloc(); + assert validate_deploy_calldata[0] = constructor_execution_context.class_hash; + assert validate_deploy_calldata[1] = salt; + memcpy( + dst=&validate_deploy_calldata[2], + src=constructor_execution_context.calldata, + len=constructor_execution_context.calldata_size, + ); + + // Note that the members of original_tx_info are not initialized at this point. + local original_tx_info: TxInfo* = constructor_execution_context.original_tx_info; + local validate_deploy_execution_context: ExecutionContext* = new ExecutionContext( + entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, + caller_address=ORIGIN_ADDRESS, + contract_address=constructor_execution_context.contract_address, + class_hash=constructor_execution_context.class_hash, + selector=VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR, + calldata_size=constructor_execution_context.calldata_size + 2, + calldata=validate_deploy_calldata, + original_tx_info=original_tx_info, + ); + + // Compute transaction hash and prepare transaction info. + let tx_version = TRANSACTION_VERSION; + local max_fee = nondet %{ tx.max_fee %}; + local nonce_ptr: felt* = cast(nondet %{ segments.gen_arg([tx.nonce]) %}, felt*); + let (transaction_hash) = compute_transaction_hash( + tx_hash_prefix=DEPLOY_ACCOUNT_HASH_PREFIX, + version=tx_version, + execution_context=validate_deploy_execution_context, + entry_point_selector_field=0, + max_fee=max_fee, + chain_id=block_context.starknet_os_config.chain_id, + additional_data_size=1, + additional_data=nonce_ptr, + ); + + // Assign the transaction info to both calls. + // Note that both constructor_execution_context and + // validate_deploy_execution_context hold this pointer. + assert [original_tx_info] = TxInfo( + version=tx_version, + account_contract_address=validate_deploy_execution_context.contract_address, + max_fee=max_fee, + signature_len=nondet %{ len(tx.signature) %}, + signature=cast(nondet %{ segments.gen_arg(arg=tx.signature) %}, felt*), + transaction_hash=transaction_hash, + chain_id=block_context.starknet_os_config.chain_id, + nonce=[nonce_ptr], + ); + + %{ syscall_handler.start_tx(tx_info_ptr=ids.original_tx_info.address_) %} + + deploy_contract( + block_context=block_context, constructor_execution_context=constructor_execution_context + ); + + // Handle nonce here since 'deploy_contract' verifies that the nonce is zeroed. + check_and_increment_nonce( + execution_context=validate_deploy_execution_context, nonce=[nonce_ptr] + ); + + // Runs the account contract's "__validate_deploy__" entry point, + // which is responsible for signature verification. + execute_entry_point( + block_context=block_context, execution_context=validate_deploy_execution_context + ); + charge_fee(block_context=block_context, tx_execution_context=validate_deploy_execution_context); + + %{ syscall_handler.end_tx() %} + return (); +} + +func execute_deploy_transaction{ + range_check_ptr, + builtin_ptrs: BuiltinPointers*, + global_state_changes: DictAccess*, + outputs: OsCarriedOutputs*, +}(block_context: BlockContext*) { + alloc_locals; + + let ( + local constructor_execution_context: ExecutionContext*, _ + ) = prepare_constructor_execution_context(); + // Guess tx version and make sure it's valid. local tx_version = nondet %{ tx.version %}; validate_transaction_version(tx_version=tx_version); diff --git a/src/starkware/starknet/core/test_contract/dummy_account.cairo b/src/starkware/starknet/core/test_contract/dummy_account.cairo index 9e19b13c..893a21a5 100644 --- a/src/starkware/starknet/core/test_contract/dummy_account.cairo +++ b/src/starkware/starknet/core/test_contract/dummy_account.cairo @@ -24,6 +24,11 @@ func __validate_declare__(class_hash: felt) { return (); } +@external +func __validate_deploy__(class_hash: felt, contract_address_salt: felt) { + return (); +} + @external func __validate__(contract_address, selector: felt, calldata_len: felt, calldata: felt*) { return (); diff --git a/src/starkware/starknet/definitions/CMakeLists.txt b/src/starkware/starknet/definitions/CMakeLists.txt index 873f39f9..68f8046e 100644 --- a/src/starkware/starknet/definitions/CMakeLists.txt +++ b/src/starkware/starknet/definitions/CMakeLists.txt @@ -8,6 +8,7 @@ python_lib(starknet_definitions_lib transaction_type.py LIBS + cairo_tracer_lib cairo_vm_crypto_lib everest_definitions_lib everest_transaction_type_lib @@ -28,6 +29,7 @@ python_lib(starknet_general_config_lib general_config.yml LIBS + cairo_all_builtins_lib cairo_instances_lib cairo_run_builtins_lib everest_general_config_lib diff --git a/src/starkware/starknet/definitions/error_codes.py b/src/starkware/starknet/definitions/error_codes.py index 3388c2d6..a9b7c1e1 100644 --- a/src/starkware/starknet/definitions/error_codes.py +++ b/src/starkware/starknet/definitions/error_codes.py @@ -9,6 +9,7 @@ class StarknetErrorCode(ErrorCode): CONTRACT_ADDRESS_UNAVAILABLE = auto() CONTRACT_BYTECODE_SIZE_TOO_LARGE = auto() CONTRACT_CLASS_OBJECT_SIZE_TOO_LARGE = auto() + DEPRECATED_TRANSACTION = auto() ENTRY_POINT_NOT_FOUND_IN_CONTRACT = auto() EXTERNAL_TO_INTERNAL_CONVERSION_ERROR = auto() FEE_TRANSFER_FAILURE = auto() diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py index 18b20aa5..9e211527 100644 --- a/src/starkware/starknet/definitions/fields.py +++ b/src/starkware/starknet/definitions/fields.py @@ -6,9 +6,11 @@ import marshmallow.utils from services.everest.definitions import fields as everest_fields +from starkware.cairo.lang.tracer.tracer_data import field_element_repr from starkware.python.utils import from_bytes from starkware.starknet.definitions import constants from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starknet.definitions.transaction_type import TransactionType from starkware.starkware_utils.field_validators import ( validate_length, validate_non_negative, @@ -16,6 +18,7 @@ ) from starkware.starkware_utils.marshmallow_dataclass_fields import ( BytesAsHex, + EnumField, FrozenDictField, IntAsHex, IntAsStr, @@ -45,6 +48,14 @@ ) +def felt_formatter(hex_felt: str) -> str: + return field_element_repr(val=int(hex_felt, 16), prime=everest_fields.FeltField.upper_bound) + + +def felt_formatter_from_int(int_felt: int) -> str: + return field_element_repr(val=int_felt, prime=everest_fields.FeltField.upper_bound) + + def bytes_as_hex_dict_keys_metadata( values_schema: Type[marshmallow.Schema], ) -> Dict[str, mfields.Dict]: @@ -284,7 +295,9 @@ def class_hash_from_bytes(class_hash: bytes) -> str: error_code=StarknetErrorCode.OUT_OF_RANGE_TRANSACTION_VERSION, formatter=hex, ) -tx_version_metadata = TransactionVersionField.metadata(required=False, load_default=0) +non_required_tx_version_metadata = TransactionVersionField.metadata(required=False, load_default=0) + +tx_version_metadata = TransactionVersionField.metadata() # State root. @@ -364,3 +377,9 @@ def class_hash_from_bytes(class_hash: bytes) -> str: load_default=dict, ) ) + +optional_tx_type_metadata = dict( + marshmallow_field=EnumField( + enum_cls=TransactionType, required=False, load_default=None, allow_none=True + ) +) diff --git a/src/starkware/starknet/definitions/general_config.yml b/src/starkware/starknet/definitions/general_config.yml index 4f282530..a7d5eacf 100644 --- a/src/starkware/starknet/definitions/general_config.yml +++ b/src/starkware/starknet/definitions/general_config.yml @@ -5,10 +5,10 @@ event_commitment_tree_height: 64 global_state_commitment_tree_height: 251 invoke_tx_max_n_steps: 1000000 min_gas_price: 100000000000 -sequencer_address: '0x6c4ae5be723ab0402cd675e7f8e1cf5a775e972e6f080ad0b842f40ec202a69' +sequencer_address: '0x388ca486b82e20cc81965d056b4cdcaacdffe0cf08e20ed8ba10ea97a487004' starknet_os_config: chain_id: TESTNET - fee_token_address: '0x5818bca45def7134510ce9874b5061ed3f8048074613817b03cd849907980d' + fee_token_address: '0x25c725399cf6de6baa0be8f1adbd93c11d34424e47e7e73f01f6557e5667d92' tx_commitment_tree_height: 64 tx_version: 1 validate_max_n_steps: 1000000 diff --git a/src/starkware/starknet/definitions/transaction_type.py b/src/starkware/starknet/definitions/transaction_type.py index 5ba8f256..0252966e 100644 --- a/src/starkware/starknet/definitions/transaction_type.py +++ b/src/starkware/starknet/definitions/transaction_type.py @@ -6,6 +6,7 @@ class TransactionType(TransactionTypeBase): DECLARE = 0 DEPLOY = auto() + DEPLOY_ACCOUNT = auto() INITIALIZE_BLOCK_INFO = auto() INVOKE_FUNCTION = auto() L1_HANDLER = auto() diff --git a/src/starkware/starknet/public/abi.py b/src/starkware/starknet/public/abi.py index 4aefb956..2a68b492 100644 --- a/src/starkware/starknet/public/abi.py +++ b/src/starkware/starknet/public/abi.py @@ -23,11 +23,13 @@ TRANSFER_ENTRY_POINT_NAME = "transfer" VALIDATE_ENTRY_POINT_NAME = "__validate__" VALIDATE_DECLARE_ENTRY_POINT_NAME = "__validate_declare__" -ACCOUNT_ENTRY_POINT_NAMES = { +VALIDATE_DEPLOY_ENTRY_POINT_NAME = "__validate_deploy__" +MANDATORY_ACCOUNT_ENTRY_POINT_NAMES = { EXECUTE_ENTRY_POINT_NAME, VALIDATE_ENTRY_POINT_NAME, VALIDATE_DECLARE_ENTRY_POINT_NAME, } +ACCOUNT_ENTRY_POINT_NAMES = {*MANDATORY_ACCOUNT_ENTRY_POINT_NAMES, VALIDATE_DEPLOY_ENTRY_POINT_NAME} AbiEntryType = Dict[str, Any] AbiType = List[AbiEntryType] @@ -55,6 +57,9 @@ def get_selector_from_name(func_name: str) -> int: VALIDATE_DECLARE_ENTRY_POINT_SELECTOR = get_selector_from_name( func_name=VALIDATE_DECLARE_ENTRY_POINT_NAME ) +VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR = get_selector_from_name( + func_name=VALIDATE_DEPLOY_ENTRY_POINT_NAME +) def get_storage_var_address(var_name: str, *args) -> int: diff --git a/src/starkware/starknet/security/CMakeLists.txt b/src/starkware/starknet/security/CMakeLists.txt index 5b13b9ef..de1fdba1 100644 --- a/src/starkware/starknet/security/CMakeLists.txt +++ b/src/starkware/starknet/security/CMakeLists.txt @@ -50,6 +50,7 @@ python_lib(starknet_hints_whitelist_lib whitelists/cairo_keccak.json whitelists/cairo_secp.json whitelists/cairo_sha256.json + whitelists/cairo_sha256_arbitrary_input_length.json whitelists/ec_bigint.json whitelists/ec_recover.json whitelists/latest.json diff --git a/src/starkware/starknet/security/starknet_common.cairo b/src/starkware/starknet/security/starknet_common.cairo index 7d1ac65a..6a2d374b 100644 --- a/src/starkware/starknet/security/starknet_common.cairo +++ b/src/starkware/starknet/security/starknet_common.cairo @@ -68,6 +68,7 @@ from starkware.cairo.common.uint256 import ( uint256_le, uint256_lt, uint256_mul, + uint256_mul_div_mod, uint256_neg, uint256_not, uint256_or, diff --git a/src/starkware/starknet/security/whitelists/cairo_sha256_arbitrary_input_length.json b/src/starkware/starknet/security/whitelists/cairo_sha256_arbitrary_input_length.json new file mode 100644 index 00000000..c643a775 --- /dev/null +++ b/src/starkware/starknet/security/whitelists/cairo_sha256_arbitrary_input_length.json @@ -0,0 +1,20 @@ +{ + "allowed_reference_expressions_for_hint": [ + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_sha256.sha256_utils import (", + " compute_message_schedule, sha2_compress_function)", + "", + "_sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS)", + "assert 0 <= _sha256_input_chunk_size_felts < 100", + "_sha256_state_size_felts = int(ids.SHA256_STATE_SIZE_FELTS)", + "assert 0 <= _sha256_state_size_felts < 100", + "w = compute_message_schedule(memory.get_range(", + " ids.sha256_start, _sha256_input_chunk_size_felts))", + "new_state = sha2_compress_function(memory.get_range(ids.state, _sha256_state_size_felts), w)", + "segments.write_arg(ids.output, new_state)" + ] + } + ] +} diff --git a/src/starkware/starknet/security/whitelists/latest.json b/src/starkware/starknet/security/whitelists/latest.json index b5bd8bb2..2fd1a735 100644 --- a/src/starkware/starknet/security/whitelists/latest.json +++ b/src/starkware/starknet/security/whitelists/latest.json @@ -59,6 +59,22 @@ "ids.is_small = 1 if ids.addr < ADDR_BOUND else 0" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "a = (ids.a.high << 128) + ids.a.low", + "b = (ids.b.high << 128) + ids.b.low", + "div = (ids.div.high << 128) + ids.div.low", + "quotient, remainder = divmod(a * b, div)", + "", + "ids.quotient_low.low = quotient & ((1 << 128) - 1)", + "ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)", + "ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)", + "ids.quotient_high.high = quotient >> 384", + "ids.remainder.low = remainder & ((1 << 128) - 1)", + "ids.remainder.high = remainder >> 128" + ] + }, { "allowed_expressions": [], "hint_lines": [ diff --git a/src/starkware/starknet/services/api/CMakeLists.txt b/src/starkware/starknet/services/api/CMakeLists.txt index 07e60710..dda07068 100644 --- a/src/starkware/starknet/services/api/CMakeLists.txt +++ b/src/starkware/starknet/services/api/CMakeLists.txt @@ -14,6 +14,7 @@ python_lib(starknet_messages_lib starknet_transaction_objects_lib starkware_dataclasses_utils_lib starkware_error_handling_lib + starkware_python_utils_lib ) python_lib(starknet_contract_class_lib diff --git a/src/starkware/starknet/services/api/feeder_gateway/response_objects.py b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py index ddf4b67c..b382ff89 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/response_objects.py +++ b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py @@ -1,4 +1,5 @@ import dataclasses +from abc import abstractmethod from dataclasses import field from enum import Enum, auto from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union @@ -8,6 +9,7 @@ import marshmallow.fields as mfields import marshmallow.utils import marshmallow_dataclass +from marshmallow.decorators import pre_load from marshmallow_oneofschema import OneOfSchema from typing_extensions import Literal from web3 import Web3 @@ -30,6 +32,7 @@ from starkware.starknet.business_logic.transaction.objects import ( InternalDeclare, InternalDeploy, + InternalDeployAccount, InternalInvokeFunction, InternalL1Handler, InternalTransaction, @@ -224,8 +227,9 @@ def __post_init__(self): @marshmallow_dataclass.dataclass(frozen=True) class TransactionSpecificInfo(ValidatedResponseObject): + transaction_hash: int = field(metadata=fields.transaction_hash_metadata) tx_type: ClassVar[TransactionType] - version: int = field(metadata=fields.tx_version_metadata) + version: int = field(metadata=fields.non_required_tx_version_metadata) @classmethod def from_internal(cls, internal_tx: InternalTransaction) -> "TransactionSpecificInfo": @@ -233,9 +237,14 @@ def from_internal(cls, internal_tx: InternalTransaction) -> "TransactionSpecific return DeclareSpecificInfo.from_internal_declare(internal_tx=internal_tx) elif isinstance(internal_tx, InternalDeploy): return DeploySpecificInfo.from_internal_deploy(internal_tx=internal_tx) + elif isinstance(internal_tx, InternalDeployAccount): + return DeployAccountSpecificInfo.from_internal_deploy_account(internal_tx=internal_tx) elif isinstance(internal_tx, InternalInvokeFunction): if internal_tx.entry_point_type is EntryPointType.L1_HANDLER: return L1HandlerSpecificInfo.from_internal_invoke(internal_tx=internal_tx) + assert ( + internal_tx.entry_point_type is EntryPointType.EXTERNAL + ), "An InternalInvokeFunction transaction must have EXTERNAL entry point type." return InvokeSpecificInfo.from_internal_invoke(internal_tx=internal_tx) elif isinstance(internal_tx, InternalL1Handler): return L1HandlerSpecificInfo.from_internal_l1_handler(internal_tx=internal_tx) @@ -243,17 +252,35 @@ def from_internal(cls, internal_tx: InternalTransaction) -> "TransactionSpecific raise NotImplementedError(f"No response object for {internal_tx}.") +# Mypy has a problem with dataclasses that contain unimplemented abstract methods. +# See https://github.com/python/mypy/issues/5374 for details on this problem. +@marshmallow_dataclass.dataclass(frozen=True) # type: ignore[misc] +class AccountTransactionSpecificInfo(TransactionSpecificInfo): + max_fee: int = field(metadata=fields.fee_metadata) + signature: List[int] = field(metadata=fields.signature_as_hex_metadata) + nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) + + @property + @abstractmethod + def account_contract_address(self) -> int: + """ + The address of the account contract initiating this transaction. + """ + + @marshmallow_dataclass.dataclass(frozen=True) -class DeclareSpecificInfo(TransactionSpecificInfo): +class DeclareSpecificInfo(AccountTransactionSpecificInfo): class_hash: int = field(metadata=fields.ClassHashIntField.metadata()) sender_address: int = field(metadata=fields.contract_address_metadata) + # Repeat `nonce` to narrow its type to non-optional int. nonce: int = field(metadata=fields.nonce_metadata) - max_fee: int = field(metadata=fields.fee_metadata) - transaction_hash: int = field(metadata=fields.transaction_hash_metadata) - signature: List[int] = field(metadata=fields.signature_as_hex_metadata) tx_type: ClassVar[TransactionType] = TransactionType.DECLARE + @property + def account_contract_address(self) -> int: + return self.sender_address + @classmethod def from_internal_declare(cls, internal_tx: InternalDeclare) -> "DeclareSpecificInfo": return cls( @@ -273,7 +300,6 @@ class DeploySpecificInfo(TransactionSpecificInfo): contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) class_hash: Optional[int] = field(metadata=fields.OptionalClassHashIntField.metadata()) constructor_calldata: List[int] = field(metadata=fields.call_data_as_hex_metadata) - transaction_hash: int = field(metadata=fields.transaction_hash_metadata) tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY @@ -290,24 +316,71 @@ def from_internal_deploy(cls, internal_tx: InternalDeploy) -> "DeploySpecificInf @marshmallow_dataclass.dataclass(frozen=True) -class InvokeSpecificInfo(TransactionSpecificInfo): +class DeployAccountSpecificInfo(AccountTransactionSpecificInfo): contract_address: int = field(metadata=fields.contract_address_metadata) - entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) - entry_point_type: EntryPointType - nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) + contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) + class_hash: int = field(metadata=fields.ClassHashIntField.metadata()) + constructor_calldata: List[int] = field(metadata=fields.call_data_as_hex_metadata) + version: int = field(metadata=fields.tx_version_metadata) + # Repeat `nonce` to narrow its type to non-optional int. + nonce: int = field(metadata=fields.nonce_metadata) + + tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY_ACCOUNT + + @property + def account_contract_address(self) -> int: + return self.contract_address + + @classmethod + def from_internal_deploy_account( + cls, internal_tx: InternalDeployAccount + ) -> "DeployAccountSpecificInfo": + return cls( + contract_address=internal_tx.contract_address, + contract_address_salt=internal_tx.contract_address_salt, + class_hash=from_bytes(internal_tx.class_hash), + constructor_calldata=internal_tx.constructor_calldata, + nonce=internal_tx.nonce, + max_fee=internal_tx.max_fee, + version=internal_tx.version, + transaction_hash=internal_tx.hash_value, + signature=internal_tx.signature, + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class InvokeSpecificInfo(AccountTransactionSpecificInfo): + contract_address: int = field(metadata=fields.contract_address_metadata) + entry_point_selector: Optional[int] = field( + metadata=fields.optional_entry_point_selector_metadata + ) calldata: List[int] = field(metadata=fields.call_data_as_hex_metadata) - signature: List[int] = field(metadata=fields.signature_as_hex_metadata) - transaction_hash: int = field(metadata=fields.transaction_hash_metadata) - max_fee: int = field(metadata=fields.fee_metadata) tx_type: ClassVar[TransactionType] = TransactionType.INVOKE_FUNCTION + @property + def account_contract_address(self) -> int: + return self.contract_address + + @pre_load + def remove_entry_point_type_and_make_selector_optional( + self, data: Dict[str, Any], many: bool, **kwargs + ) -> Dict[str, List[str]]: + if "entry_point_type" in data: + del data["entry_point_type"] + + version = fields.TransactionVersionField.load_value(data["version"]) + if version != 0: + data["entry_point_selector"] = None + return data + @classmethod def from_internal_invoke(cls, internal_tx: InternalInvokeFunction) -> "InvokeSpecificInfo": return cls( contract_address=internal_tx.contract_address, - entry_point_selector=internal_tx.entry_point_selector, - entry_point_type=internal_tx.entry_point_type, + entry_point_selector=( + None if internal_tx.version != 0 else internal_tx.entry_point_selector + ), nonce=internal_tx.nonce, calldata=internal_tx.calldata, version=internal_tx.version, @@ -321,9 +394,8 @@ def from_internal_invoke(cls, internal_tx: InternalInvokeFunction) -> "InvokeSpe class L1HandlerSpecificInfo(TransactionSpecificInfo): contract_address: int = field(metadata=fields.contract_address_metadata) entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) - nonce: int = field(metadata=fields.nonce_metadata) + nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) calldata: List[int] = field(metadata=fields.call_data_as_hex_metadata) - transaction_hash: int = field(metadata=fields.transaction_hash_metadata) tx_type: ClassVar[TransactionType] = TransactionType.L1_HANDLER @@ -346,7 +418,7 @@ def from_internal_invoke(cls, internal_tx: InternalInvokeFunction) -> "L1Handler return cls( contract_address=internal_tx.contract_address, entry_point_selector=internal_tx.entry_point_selector, - nonce=internal_tx.nonce if internal_tx.nonce is not None else 0, + nonce=internal_tx.nonce, calldata=internal_tx.calldata, version=constants.L1_HANDLER_VERSION, transaction_hash=internal_tx.hash_value, @@ -357,6 +429,7 @@ class TransactionSpecificInfoSchema(OneOfSchema): type_schemas: Dict[str, Type[marshmallow.Schema]] = { TransactionType.DECLARE.name: DeclareSpecificInfo.Schema, TransactionType.DEPLOY.name: DeploySpecificInfo.Schema, + TransactionType.DEPLOY_ACCOUNT.name: DeployAccountSpecificInfo.Schema, TransactionType.INVOKE_FUNCTION.name: InvokeSpecificInfo.Schema, TransactionType.L1_HANDLER.name: L1HandlerSpecificInfo.Schema, } diff --git a/src/starkware/starknet/services/api/gateway/transaction.py b/src/starkware/starknet/services/api/gateway/transaction.py index de0be724..9f59eae4 100644 --- a/src/starkware/starknet/services/api/gateway/transaction.py +++ b/src/starkware/starknet/services/api/gateway/transaction.py @@ -9,10 +9,14 @@ from marshmallow_oneofschema import OneOfSchema from services.everest.api.gateway.transaction import EverestTransaction -from starkware.starknet.core.os.contract_address.contract_address import calculate_contract_address +from starkware.starknet.core.os.contract_address.contract_address import ( + calculate_contract_address, + calculate_contract_address_from_hash, +) from starkware.starknet.core.os.transaction_hash.transaction_hash import ( TransactionHashPrefix, calculate_declare_transaction_hash, + calculate_deploy_account_transaction_hash, calculate_deploy_transaction_hash, calculate_transaction_hash_common, ) @@ -42,7 +46,7 @@ class Transaction(EverestTransaction): # signed by the account contract. # This field allows invalidating old transactions, whenever the meaning of the other # transaction fields is changed (in the OS). - version: int = field(metadata=fields.tx_version_metadata) + version: int = field(metadata=fields.non_required_tx_version_metadata) @property @classmethod @@ -92,6 +96,7 @@ class Declare(AccountTransaction): contract_class: ContractClass # The address of the account contract sending the declaration transaction. sender_address: int = field(metadata=fields.contract_address_metadata) + # Repeat `nonce` to narrow its type to non-optional int. nonce: int = field(metadata=fields.nonce_metadata) # Class variables. @@ -170,6 +175,45 @@ def calculate_hash(self, general_config: StarknetGeneralConfig) -> int: ) +@marshmallow_dataclass.dataclass(frozen=True) +class DeployAccount(AccountTransaction): + """ + Represents a transaction in the StarkNet network that is a deployment of a StarkNet account + contract. + """ + + class_hash: int = field(metadata=fields.ClassHashIntField.metadata()) + contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) + constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + version: int = field(metadata=fields.tx_version_metadata) + # Repeat `nonce` to narrow its type to non-optional int. + nonce: int = field(metadata=fields.nonce_metadata) + + # Class variables. + tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY_ACCOUNT + + def calculate_hash(self, general_config: StarknetGeneralConfig) -> int: + """ + Calculates the transaction hash in the StarkNet network. + """ + contract_address = calculate_contract_address_from_hash( + salt=self.contract_address_salt, + class_hash=self.class_hash, + constructor_calldata=self.constructor_calldata, + deployer_address=0, + ) + return calculate_deploy_account_transaction_hash( + version=self.version, + contract_address=contract_address, + class_hash=self.class_hash, + constructor_calldata=self.constructor_calldata, + max_fee=self.max_fee, + nonce=self.nonce, + salt=self.contract_address_salt, + chain_id=general_config.chain_id.value, + ) + + @marshmallow_dataclass.dataclass(frozen=True) class InvokeFunction(AccountTransaction): """ @@ -246,6 +290,7 @@ class AccountTransactionSchema(OneOfSchema): type_schemas: Dict[str, Type[marshmallow.Schema]] = { TransactionType.DECLARE.name: Declare.Schema, + TransactionType.DEPLOY_ACCOUNT.name: DeployAccount.Schema, TransactionType.INVOKE_FUNCTION.name: InvokeFunction.Schema, } diff --git a/src/starkware/starknet/services/api/messages.py b/src/starkware/starknet/services/api/messages.py index fa39ee28..e7ef5b37 100644 --- a/src/starkware/starknet/services/api/messages.py +++ b/src/starkware/starknet/services/api/messages.py @@ -5,6 +5,7 @@ from services.everest.definitions import fields as everest_fields from starkware.cairo.bootloaders.compute_fact import keccak_ints +from starkware.python.utils import as_non_optional from starkware.starknet.business_logic.transaction.objects import InternalL1Handler from starkware.starknet.definitions import fields from starkware.starkware_utils.validated_dataclass import ValidatedDataclass @@ -75,5 +76,5 @@ def get_message_hash_from_tx(tx: InternalL1Handler) -> str: to_address=tx.contract_address, l1_handler_selector=tx.entry_point_selector, payload=tx.calldata[1:], - nonce=tx.nonce, + nonce=as_non_optional(tx.nonce), ).get_hash() diff --git a/src/starkware/starknet/services/utils/sequencer_api_utils.py b/src/starkware/starknet/services/utils/sequencer_api_utils.py index 546ebaba..9c9285c6 100644 --- a/src/starkware/starknet/services/utils/sequencer_api_utils.py +++ b/src/starkware/starknet/services/utils/sequencer_api_utils.py @@ -7,6 +7,7 @@ from starkware.starknet.business_logic.transaction.objects import ( InternalAccountTransaction, InternalDeclare, + InternalDeployAccount, InternalInvokeFunction, InternalTransaction, ) @@ -16,6 +17,7 @@ from starkware.starknet.services.api.gateway.transaction import ( AccountTransaction, Declare, + DeployAccount, InvokeFunction, ) from starkware.starkware_utils.config_base import Config @@ -48,6 +50,8 @@ def from_external( internal_cls = InternalInvokeFunctionForSimulate elif isinstance(external_tx, Declare): internal_cls = InternalDeclareForSimulate + elif isinstance(external_tx, DeployAccount): + internal_cls = InternalDeployAccountForSimulate else: raise NotImplementedError(f"Unexpected type {type(external_tx).__name__}.") @@ -56,7 +60,7 @@ def from_external( ) def verify_version(self): - verify_version(version=self.version, only_query=True) + verify_version(version=self.version, only_query=True, old_supported_versions=[0]) def charge_fee( self, state: SyncState, resources: ResourcesMapping, general_config: StarknetGeneralConfig @@ -83,3 +87,14 @@ class InternalDeclareForSimulate(InternalAccountTransactionForSimulate, Internal """ Represents an internal declare in the StarkNet network for the simulate transaction API. """ + + +class InternalDeployAccountForSimulate( + InternalAccountTransactionForSimulate, InternalDeployAccount +): + """ + Represents an internal deploy account in the StarkNet network for the simulate transaction API. + """ + + def verify_version(self): + verify_version(version=self.version, only_query=True, old_supported_versions=[]) diff --git a/src/starkware/starknet/storage/starknet_storage.py b/src/starkware/starknet/storage/starknet_storage.py index ae2fd16f..88326635 100644 --- a/src/starkware/starknet/storage/starknet_storage.py +++ b/src/starkware/starknet/storage/starknet_storage.py @@ -327,6 +327,7 @@ def __init__( commitment_tree: PatriciaTree, updated_commitment_tree: PatriciaTree, commitment_tree_facts: BinaryFactDict, + ongoing_storage_changes: Dict[int, int], ): """ The constructor is private. @@ -337,6 +338,7 @@ def __init__( # entering the CairoRunner run) for optimization. self.updated_commitment_tree = updated_commitment_tree self.commitment_tree_facts = commitment_tree_facts + self.ongoing_storage_changes = ongoing_storage_changes def commitment_update(self) -> Tuple[PatriciaTree, BinaryFactDict]: return self.updated_commitment_tree, self.commitment_tree_facts @@ -365,8 +367,24 @@ async def create( actual_updated_commitment_tree == updated_commitment_tree ), "Inconsistent commitment tree roots." + # Fetch initial values of keys accessed by this contract. + initial_leaves = await previous_commitment_tree.get_leaves( + ffc=ffc, indices=accessed_addresses, fact_cls=StorageLeaf + ) + initial_entries = {key: leaf.value for key, leaf in initial_leaves.items()} return cls( commitment_tree=previous_commitment_tree, updated_commitment_tree=updated_commitment_tree, commitment_tree_facts=commitment_tree_facts, + ongoing_storage_changes=initial_entries, ) + + def write(self, key: int, value: int) -> int: + """ + Writes the given value in the given key in ongoing_storage_changes and returns the + previous value. This value is needed to create the DictAccess while executing the + corresponding storage_write system call. + """ + previous_value = self.ongoing_storage_changes[key] + self.ongoing_storage_changes[key] = value + return previous_value diff --git a/src/starkware/starknet/testing/MockStarknetMessaging.sol b/src/starkware/starknet/testing/MockStarknetMessaging.sol index fa4b0baf..e4e4afc7 100644 --- a/src/starkware/starknet/testing/MockStarknetMessaging.sol +++ b/src/starkware/starknet/testing/MockStarknetMessaging.sol @@ -37,6 +37,6 @@ contract MockStarknetMessaging is StarknetMessaging { ); require(l1ToL2Messages()[msgHash] > 0, "INVALID_MESSAGE_TO_CONSUME"); - l1ToL2Messages()[msgHash] -= 1; + l1ToL2Messages()[msgHash] = 0; } } diff --git a/src/starkware/starknet/testing/starknet.py b/src/starkware/starknet/testing/starknet.py index 806e4a0b..73ac9052 100644 --- a/src/starkware/starknet/testing/starknet.py +++ b/src/starkware/starknet/testing/starknet.py @@ -42,13 +42,17 @@ async def declare( source: Optional[str] = None, contract_class: Optional[ContractClass] = None, cairo_path: Optional[List[str]] = None, + disable_hint_validation: bool = False, ) -> DeclaredClass: """ Declares a ContractClass in the StarkNet network. Returns the class hash and the ABI of the contract. """ contract_class = get_contract_class( - source=source, contract_class=contract_class, cairo_path=cairo_path + source=source, + contract_class=contract_class, + cairo_path=cairo_path, + disable_hint_validation=disable_hint_validation, ) class_hash, _ = await self.state.declare(contract_class=contract_class) assert class_hash is not None diff --git a/src/starkware/starknet/testing/starknet_test.py b/src/starkware/starknet/testing/starknet_test.py index 5ed4f4d4..a6ac6887 100644 --- a/src/starkware/starknet/testing/starknet_test.py +++ b/src/starkware/starknet/testing/starknet_test.py @@ -4,6 +4,7 @@ import pytest import pytest_asyncio +from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.starknet.compiler.compile import compile_starknet_files from starkware.starknet.testing.contract import StarknetContract from starkware.starknet.testing.starknet import Starknet @@ -107,9 +108,29 @@ async def test_struct_arrays(starknet: Starknet): await contract.transpose([(123, 234), (4, 5, 6)]).execute() +@pytest.mark.asyncio +async def test_declare_unwhitelisted_hint_contract(starknet: Starknet): + with pytest.raises( + PreprocessorError, + match=re.escape( + "This may indicate that this library function cannot be used in StarkNet contracts." + ), + ): + await starknet.declare(source=HINT_CONTRACT_FILE) + + # Check that declare() does not throw an error with disable_hint_validation. + await starknet.declare(source=HINT_CONTRACT_FILE, disable_hint_validation=True) + + @pytest.mark.asyncio async def test_deploy_unwhitelisted_hint_contract(starknet: Starknet): - deployed_contract = await starknet.deploy( - source=HINT_CONTRACT_FILE, disable_hint_validation=True - ) - assert isinstance(deployed_contract.contract_address, int) + with pytest.raises( + PreprocessorError, + match=re.escape( + "This may indicate that this library function cannot be used in StarkNet contracts." + ), + ): + await starknet.deploy(source=HINT_CONTRACT_FILE) + + # Check that deploy() does not throw an error with disable_hint_validation. + await starknet.deploy(source=HINT_CONTRACT_FILE, disable_hint_validation=True) diff --git a/src/starkware/starknet/third_party/open_zeppelin/Account.cairo b/src/starkware/starknet/third_party/open_zeppelin/Account.cairo index f044fbe8..a64ee915 100644 --- a/src/starkware/starknet/third_party/open_zeppelin/Account.cairo +++ b/src/starkware/starknet/third_party/open_zeppelin/Account.cairo @@ -146,6 +146,15 @@ func __validate_declare__{ return (); } +@external +func __validate_deploy__{ + syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, ecdsa_ptr: SignatureBuiltin* +}(class_hash: felt, contract_address_salt: felt, _public_key: felt) { + let (tx_info) = get_tx_info(); + is_valid_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature); + return (); +} + @external func __validate__{ syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, ecdsa_ptr: SignatureBuiltin* diff --git a/src/starkware/starknet/wallets/CMakeLists.txt b/src/starkware/starknet/wallets/CMakeLists.txt index 36ba86c0..d324d2be 100644 --- a/src/starkware/starknet/wallets/CMakeLists.txt +++ b/src/starkware/starknet/wallets/CMakeLists.txt @@ -10,6 +10,7 @@ python_lib(starknet_wallets_lib starknet_definitions_lib starknet_feeder_gateway_client_lib starknet_gateway_client_lib + starknet_transaction_lib starkware_crypto_lib ) @@ -29,9 +30,9 @@ python_lib(starknet_standard_wallets_lib starknet_contract_class_lib starknet_definitions_lib starknet_feeder_gateway_response_objects_lib + starknet_os_abi_lib starknet_transaction_hash_lib starknet_transaction_lib starknet_wallets_lib starkware_crypto_lib - starkware_error_handling_lib ) diff --git a/src/starkware/starknet/wallets/account.py b/src/starkware/starknet/wallets/account.py index 0479ce0f..f77fb533 100644 --- a/src/starkware/starknet/wallets/account.py +++ b/src/starkware/starknet/wallets/account.py @@ -1,40 +1,43 @@ -import dataclasses from abc import ABC, abstractmethod from typing import Awaitable, Callable, List, Tuple from starkware.starknet.services.api.contract_class import ContractClass +from starkware.starknet.services.api.gateway.transaction import ( + Declare, + DeployAccount, + InvokeFunction, +) from starkware.starknet.wallets.starknet_context import StarknetContext DEFAULT_ACCOUNT_DIR = "~/.starknet_accounts" -@dataclasses.dataclass -class WrappedMethod: - address: int - selector: int - calldata: List[int] - max_fee: int - signature: List[int] - nonce: int - - class Account(ABC): @classmethod @abstractmethod - async def create(cls, starknet_context: StarknetContext, account_name: str) -> "Account": + def create(cls, starknet_context: StarknetContext, account_name: str) -> "Account": """ Constructs an instance of the class. """ @abstractmethod - async def deploy(self): + def new_account(self) -> int: + """ + Initializes the account. For example, this may include choosing a new random private key. + Returns the contract address of the new account. + """ + + @abstractmethod + async def deploy_account( + self, max_fee: int, version: int, chain_id: int, dry_run: bool = False + ) -> Tuple[DeployAccount, int]: """ - Initializes the account. For example, this may include choosing a new random private key - and deploying the account contract to the network. + Prepares the deployment of the initialized account contract to the network. + Returns the transaction and the new account address. """ @abstractmethod - async def sign_invoke_transaction( + async def invoke( self, contract_address: int, selector: int, @@ -44,12 +47,13 @@ async def sign_invoke_transaction( version: int, nonce_callback: Callable[[int], Awaitable[int]], dry_run: bool = False, - ) -> WrappedMethod: + ) -> InvokeFunction: """ - Given a transaction to execute (or call) within the context of the account, - prepares the required information for invoking it through the account contract. - nonce is the nonce to be used in the transaction. If not specified, the current nonce - is queried from the StarkNet system. + Given a function (contract address, selector, calldata) to invoke (or call) within the + context of the account, prepares the required information for invoking it through the + account contract. + nonce_callback is a callback that gets the address of the contract and returns the next + nonce to use. """ @abstractmethod @@ -63,11 +67,11 @@ async def deploy_contract( max_fee: int, version: int, nonce_callback: Callable[[int], Awaitable[int]], - ) -> Tuple[WrappedMethod, int]: + ) -> Tuple[InvokeFunction, int]: """ Prepares the required information for invoking a contract deployment function through the account contract. - Returns the wrapped method and the deployed contract address. + Returns the signed transaction and the deployed contract address. """ @abstractmethod @@ -79,7 +83,7 @@ async def declare( version: int, nonce_callback: Callable[[int], Awaitable[int]], dry_run: bool = False, - ) -> WrappedMethod: + ) -> Declare: """ Prepares the required information for declaring a contract class through the account contract. diff --git a/src/starkware/starknet/wallets/open_zeppelin.py b/src/starkware/starknet/wallets/open_zeppelin.py index e9a278e2..b43a584c 100644 --- a/src/starkware/starknet/wallets/open_zeppelin.py +++ b/src/starkware/starknet/wallets/open_zeppelin.py @@ -5,22 +5,28 @@ from services.external_api.client import JsonObject from starkware.crypto.signature.signature import get_random_private_key, private_to_stark_key, sign +from starkware.starknet.core.os.class_hash import compute_class_hash from starkware.starknet.core.os.contract_address.contract_address import ( + calculate_contract_address, calculate_contract_address_from_hash, ) from starkware.starknet.core.os.transaction_hash.transaction_hash import ( TransactionHashPrefix, calculate_declare_transaction_hash, + calculate_deploy_account_transaction_hash, calculate_transaction_hash_common, ) -from starkware.starknet.definitions import constants, fields -from starkware.starknet.public.abi import EXECUTE_ENTRY_POINT_SELECTOR, get_selector_from_name +from starkware.starknet.definitions import fields +from starkware.starknet.public.abi import get_selector_from_name from starkware.starknet.services.api.contract_class import ContractClass -from starkware.starknet.services.api.gateway.transaction import Deploy +from starkware.starknet.services.api.gateway.transaction import ( + Declare, + DeployAccount, + InvokeFunction, +) from starkware.starknet.third_party.open_zeppelin.starknet_contracts import account_contract -from starkware.starknet.wallets.account import Account, WrappedMethod +from starkware.starknet.wallets.account import Account from starkware.starknet.wallets.starknet_context import StarknetContext -from starkware.starkware_utils.error_handling import StarkErrorCode ACCOUNT_FILE_NAME = "starknet_open_zeppelin_accounts.json" DEPLOY_CONTRACT_SELECTOR = get_selector_from_name("deploy_contract") @@ -37,9 +43,7 @@ def __init__(self, starknet_context: StarknetContext, account_name: str): self.starknet_context = starknet_context @classmethod - async def create( - cls, starknet_context: StarknetContext, account_name: str - ) -> "OpenZeppelinAccount": + def create(cls, starknet_context: StarknetContext, account_name: str) -> "OpenZeppelinAccount": return cls(starknet_context=starknet_context, account_name=account_name) @property @@ -56,18 +60,9 @@ async def declare( version: int, nonce_callback: Callable[[int], Awaitable[int]], dry_run: bool = False, - ) -> WrappedMethod: - account = self.get_account_information() - account_address = int(account["address"], 16) - - private_key: Optional[int] - if "private_key" in account: - private_key = int(account["private_key"], 16) - else: - assert dry_run, f"Missing private key for {hex(account_address)}" - private_key = None - - return sign_declare_transaction( + ) -> Declare: + account_address, private_key = self._get_account_address_and_private_key(dry_run=dry_run) + return sign_declare_tx( contract_class=contract_class, private_key=private_key, sender_address=account_address, @@ -77,17 +72,30 @@ async def declare( nonce=await nonce_callback(account_address), ) - async def deploy(self): + def _get_accounts(self) -> dict: # Read the account file. if os.path.exists(self.account_file): + # First, load the file, and make sure it's in JSON format. + accounts = json.load(open(self.account_file)) # Make a backup of the file. shutil.copy(self.account_file, self.account_file + ".backup") - accounts = json.load(open(self.account_file)) else: accounts = {} + return accounts - accounts_for_network = accounts.setdefault(self.starknet_context.network_id, {}) + def _get_account_given_accounts(self, accounts: dict) -> JsonObject: + accounts_for_network = accounts.get(self.starknet_context.network_id, {}) + if self.account_name not in accounts_for_network: + raise AccountNotFoundException( + f"Account '{self.account_name}' for network '{self.starknet_context.network_id}' " + "was not found. You can create a new account using the 'new_account' command." + ) + return accounts_for_network[self.account_name] + def new_account(self) -> int: + # Read the account file. + accounts = self._get_accounts() + accounts_for_network = accounts.setdefault(self.starknet_context.network_id, {}) assert self.account_name not in accounts_for_network, ( f"Account '{self.account_name}' for network '{self.starknet_context.network_id}' " "already exists." @@ -95,84 +103,103 @@ async def deploy(self): private_key = get_random_private_key() public_key = private_to_stark_key(private_key) - - # Deploy the contract. salt = fields.ContractAddressSalt.get_random_value() - - tx = Deploy( - contract_address_salt=salt, - contract_definition=account_contract, + contract_address = calculate_contract_address( + salt=salt, + contract_class=account_contract, constructor_calldata=[public_key], - version=constants.TRANSACTION_VERSION, + deployer_address=0, ) - gateway_response = await self.starknet_context.gateway_client.add_transaction(tx=tx) - assert ( - gateway_response["code"] == StarkErrorCode.TRANSACTION_RECEIVED.name - ), f"Failed to send deploy transaction. Response: {gateway_response}." - contract_address = int(gateway_response["address"], 16) - accounts_for_network[self.account_name] = { "private_key": hex(private_key), "public_key": hex(public_key), + "salt": hex(salt), "address": hex(contract_address), + "deployed": False, } # Don't end sentences with '.', to allow easy double-click copy-pasting of the values. print( f"""\ -Sent deploy account contract transaction. +Account address: 0x{contract_address:064x} +Public key: 0x{public_key:064x} +Move the appropriate amount of funds to the account, and then deploy the account +by invoking the 'starknet deploy_account' command. NOTE: This is a modified version of the OpenZeppelin account contract. The signature is computed differently. - -Contract address: 0x{contract_address:064x} -Public key: 0x{public_key:064x} -Transaction hash: {gateway_response['transaction_hash']} """ ) + os.makedirs(name=os.path.dirname(self.account_file), exist_ok=True) with open(self.account_file, "w") as f: json.dump(accounts, f, indent=4) f.write("\n") - def get_account_information(self) -> JsonObject: + return contract_address + + async def deploy_account( + self, max_fee: int, version: int, chain_id: int, dry_run: bool = False + ) -> Tuple[DeployAccount, int]: + # Read the account file. + accounts = self._get_accounts() + account_to_deploy = self._get_account_given_accounts(accounts=accounts) + tx = sign_deploy_account_tx( + private_key=int(account_to_deploy["private_key"], 16), + public_key=int(account_to_deploy["public_key"], 16), + class_hash=compute_class_hash(account_contract), + salt=int(account_to_deploy["salt"], 16), + max_fee=max_fee, + version=version, + chain_id=chain_id, + ) + contract_address = int(account_to_deploy["address"], 16) + + if dry_run: + return tx, contract_address + + assert account_to_deploy["deployed"] is False, ( + f"Account '{self.account_name}' for network '{self.starknet_context.network_id}' " + "is already deployed." + ) + account_to_deploy["deployed"] = True + os.makedirs(name=os.path.dirname(self.account_file), exist_ok=True) + with open(self.account_file, "w") as f: + json.dump(accounts, f, indent=4) + f.write("\n") + + return tx, contract_address + + def _get_deployed_account_info(self) -> JsonObject: assert os.path.exists(self.account_file), ( f"The account file '{self.account_file}' was not found.\n" - "Did you deploy your account contract (using 'starnet deploy_account')?" + "Did you deploy your account contract (using 'starknet new_account' " + "and 'starknet deploy_account')?" ) - accounts = json.load(open(self.account_file)) - accounts_for_network = accounts.get(self.starknet_context.network_id, {}) - if self.account_name not in accounts_for_network: - raise AccountNotFoundException( - f"Account '{self.account_name}' for network '{self.starknet_context.network_id}' " - "was not found." - ) - return accounts_for_network[self.account_name] + accounts = self._get_accounts() + account = self._get_account_given_accounts(accounts=accounts) + assert account["deployed"], ( + f"Account '{self.account_name}' for network '{self.starknet_context.network_id}' " + "is not deployed; use 'starknet deploy_account' command." + ) - async def sign_invoke_transaction( + return account + + async def invoke( self, contract_address: int, selector: int, calldata: List[int], chain_id: int, - max_fee: Optional[int], + max_fee: int, version: int, nonce_callback: Callable[[int], Awaitable[int]], dry_run: bool = False, - ) -> WrappedMethod: - account = self.get_account_information() - account_address = int(account["address"], 16) - - private_key: Optional[int] - if "private_key" in account: - private_key = int(account["private_key"], 16) - else: - assert dry_run, f"Missing private_key for {hex(account_address)}." - private_key = None - - return sign_invoke_transaction( + ) -> InvokeFunction: + account_address, private_key = self._get_account_address_and_private_key(dry_run=dry_run) + return sign_invoke_tx( signer_address=account_address, private_key=private_key, contract_address=contract_address, @@ -191,11 +218,11 @@ async def deploy_contract( constructor_calldata: List[int], deploy_from_zero: bool, chain_id: int, - max_fee: Optional[int], + max_fee: int, version: int, nonce_callback: Callable[[int], Awaitable[int]], - ) -> Tuple[WrappedMethod, int]: - account = self.get_account_information() + ) -> Tuple[InvokeFunction, int]: + account = self._get_deployed_account_info() account_address = int(account["address"], 16) deploy_from_zero_felt = 1 if deploy_from_zero else 0 calldata = [ @@ -206,7 +233,7 @@ async def deploy_contract( deploy_from_zero_felt, ] - wrapped_invocation = await self.sign_invoke_transaction( + tx = await self.invoke( contract_address=account_address, selector=DEPLOY_CONTRACT_SELECTOR, calldata=calldata, @@ -223,10 +250,23 @@ async def deploy_contract( deployer_address=0 if deploy_from_zero else account_address, ) - return wrapped_invocation, contract_address + return tx, contract_address + + def _get_account_address_and_private_key(self, dry_run: bool) -> Tuple[int, Optional[int]]: + account = self._get_deployed_account_info() + account_address = int(account["address"], 16) + + private_key: Optional[int] + if "private_key" in account: + private_key = int(account["private_key"], 16) + else: + assert dry_run, f"Missing private_key for {hex(account_address)}." + private_key = None + + return account_address, private_key -def sign_declare_transaction( +def sign_declare_tx( contract_class: ContractClass, private_key: Optional[int], sender_address: int, @@ -234,7 +274,7 @@ def sign_declare_transaction( max_fee: int, version: int, nonce: int, -) -> WrappedMethod: +) -> Declare: hash_value = calculate_declare_transaction_hash( contract_class=contract_class, chain_id=chain_id, @@ -243,41 +283,41 @@ def sign_declare_transaction( version=version, nonce=nonce, ) - if private_key is None: - signature = [] - else: - signature = list(sign(msg_hash=hash_value, priv_key=private_key)) - return WrappedMethod( - address=sender_address, - selector=0, - calldata=[], + + return Declare( + contract_class=contract_class, + sender_address=sender_address, max_fee=max_fee, - signature=signature, + signature=( + [] if private_key is None else list(sign(msg_hash=hash_value, priv_key=private_key)) + ), nonce=nonce, + version=version, ) -def sign_invoke_transaction( +def sign_invoke_tx( signer_address: int, private_key: Optional[int], contract_address: int, selector: int, calldata: List[int], chain_id: int, - max_fee: Optional[int], + max_fee: int, version: int, nonce: int, -) -> WrappedMethod: +) -> InvokeFunction: """ - Calculates the transaction's hash and then computes the signature using the private key. - Returns a WrappedMethod with the signature of the sender. + Given a function to invoke (contract address, selector, calldata) and account identifiers + (signer address, private key) prepares and signs an OpenZeppelin account invocation to this + function. """ data_offset = 0 data_len = len(calldata) call_entry = [contract_address, selector, data_offset, data_len] call_array_len = 1 wrapped_method_calldata = [call_array_len, *call_entry, len(calldata), *calldata] - max_fee = 0 if max_fee is None else max_fee + hash_value = calculate_transaction_hash_common( tx_hash_prefix=TransactionHashPrefix.INVOKE, version=version, @@ -288,15 +328,54 @@ def sign_invoke_transaction( chain_id=chain_id, additional_data=[nonce], ) - if private_key is None: - signature = [] - else: - signature = list(sign(msg_hash=hash_value, priv_key=private_key)) - return WrappedMethod( - address=signer_address, - selector=EXECUTE_ENTRY_POINT_SELECTOR, + + return InvokeFunction( + contract_address=signer_address, calldata=wrapped_method_calldata, max_fee=max_fee, nonce=nonce, - signature=signature, + signature=( + [] if private_key is None else list(sign(msg_hash=hash_value, priv_key=private_key)) + ), + version=version, + ) + + +def sign_deploy_account_tx( + private_key: Optional[int], + public_key: int, + class_hash: int, + salt: int, + max_fee: int, + version: int, + chain_id: int, + nonce: int = 0, +) -> DeployAccount: + contract_address = calculate_contract_address_from_hash( + salt=salt, + class_hash=class_hash, + constructor_calldata=[public_key], + deployer_address=0, + ) + hash_value = calculate_deploy_account_transaction_hash( + contract_address=contract_address, + class_hash=class_hash, + constructor_calldata=[public_key], + salt=salt, + max_fee=max_fee, + version=version, + chain_id=chain_id, + nonce=nonce, + ) + + return DeployAccount( + class_hash=class_hash, + constructor_calldata=[public_key], + contract_address_salt=salt, + max_fee=max_fee, + nonce=nonce, + signature=( + [] if private_key is None else list(sign(msg_hash=hash_value, priv_key=private_key)) + ), + version=version, ) diff --git a/src/starkware/starkware_utils/error_handling.py b/src/starkware/starkware_utils/error_handling.py index cdf7da8f..7b216bb0 100644 --- a/src/starkware/starkware_utils/error_handling.py +++ b/src/starkware/starkware_utils/error_handling.py @@ -156,7 +156,7 @@ class StarkException(WebFriendlyException): an invalid transaction). """ - def __init__(self, code, message: Optional[str] = None): + def __init__(self, code: ErrorCode, message: Optional[str] = None): self.code = code self.message = message super().__init__(status_code=500, body={"code": code, "message": message}) @@ -252,12 +252,15 @@ def wrap_with_stark_exception( try: yield + except StarkException: + # Raise StarkException-s as-is, so failure information is not lost. + raise except tuple(exception_types) as exception: message = str(exception) if message is None else message if logger is not None: logger.error(message, exc_info=True) - raise StarkException(code=code, message=message) + raise StarkException(code=code, message=message) from exception @dataclasses.dataclass(frozen=True) diff --git a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py index ae41c38d..bb3a8e87 100644 --- a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py +++ b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py @@ -2,7 +2,8 @@ import functools import re from abc import ABC, abstractmethod -from typing import Any, Callable, Dict +from enum import Enum +from typing import Any, Callable, Dict, Type import marshmallow.fields as mfields from frozendict import frozendict @@ -60,7 +61,9 @@ class EnumField(mfields.Field): A field that behaves like an enum, but serializes to a string. """ - def __init__(self, enum_cls, required: bool = False, allow_none: bool = False, **kwargs): + def __init__( + self, enum_cls: Type[Enum], required: bool = False, allow_none: bool = False, **kwargs + ): self.enum_cls = enum_cls super().__init__(required=required, allow_none=allow_none, **kwargs) From abe37822cb264c9298f6c35df769f4e29ce82127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20B=C3=A4ttig?= Date: Wed, 5 Oct 2022 20:57:38 +0200 Subject: [PATCH 2/4] cli raw input support --- src/starkware/starknet/cli/starknet_cli.py | 240 +++++++++++++-------- 1 file changed, 153 insertions(+), 87 deletions(-) diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index 24d321f1..37e92a32 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -141,7 +141,8 @@ def get_arg_value(args, arg_name: str, environment_var: str) -> str: Same as get_optional_arg_value, except that if the value is not defined, an exception is raised. """ - value = get_optional_arg_value(args=args, arg_name=arg_name, environment_var=environment_var) + value = get_optional_arg_value( + args=args, arg_name=arg_name, environment_var=environment_var) if value is None: raise Exception( f'{arg_name} must be specified with the "{args.command}" subcommand.\n' @@ -151,14 +152,16 @@ def get_arg_value(args, arg_name: str, environment_var: str) -> str: def get_chain_id(args) -> int: - chain_id = get_arg_value(args=args, arg_name="chain_id", environment_var="STARKNET_CHAIN_ID") + chain_id = get_arg_value( + args=args, arg_name="chain_id", environment_var="STARKNET_CHAIN_ID") if chain_id.startswith("0x"): chain_id_int = int(chain_id, 16) else: chain_id_int = from_bytes(chain_id.encode()) - assert chain_id_int in CHAIN_IDS.values(), f"Unsupported chain ID: {chain_id}." + assert chain_id_int in CHAIN_IDS.values( + ), f"Unsupported chain ID: {chain_id}." return chain_id_int @@ -174,7 +177,8 @@ def get_wallet_provider(args) -> Optional[str]: """ Returns the name of the wallet provider (of the form "module.class") as defined by the user. """ - value = get_optional_arg_value(args=args, arg_name="wallet", environment_var="STARKNET_WALLET") + value = get_optional_arg_value( + args=args, arg_name="wallet", environment_var="STARKNET_WALLET") assert value is not None, ( "A wallet must be specified (using --wallet or the STARKNET_WALLET environment variable), " "unless specifically using --no_wallet." @@ -241,7 +245,8 @@ def parse_hex_arg(arg: str, arg_name: str) -> int: Converts the given argument (hex string, starting with "0x") to an integer. """ arg = arg.strip() - assert arg.startswith("0x"), f"{arg_name} must start with '0x'. Got: '{arg}'." + assert arg.startswith( + "0x"), f"{arg_name} must start with '0x'. Got: '{arg}'." try: return int(arg, 16) except ValueError: @@ -281,7 +286,8 @@ async def compute_max_fee( tx=tx, has_block_info=False, ) - max_fee = math.ceil(simulate_tx_info.fee_estimation.overall_fee * FEE_MARGIN_OF_ESTIMATION) + max_fee = math.ceil( + simulate_tx_info.fee_estimation.overall_fee * FEE_MARGIN_OF_ESTIMATION) max_fee_eth = float(Web3.fromWei(max_fee, "ether")) print(f"Sending the transaction with max_fee: {max_fee_eth:.6f} ETH.") @@ -310,7 +316,8 @@ def validate_arguments( try: typ = mark_type_resolved(parse_type(input_desc["type"])) - typ_size = check_felts_only_type(cairo_type=typ, identifier_manager=identifier_manager) + typ_size = check_felts_only_type( + cairo_type=typ, identifier_manager=identifier_manager) except Exception as ex: raise AbiFormatError(ex) from ex @@ -329,7 +336,8 @@ def validate_arguments( raise AbiFormatError(ex) from ex if typ_size is None: - raise AbiFormatError(ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) + raise AbiFormatError( + ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) assert previous_felt_input is not None, ( f"The array argument {input_desc['name']} of type felt* must be preceded " "by a length argument of type felt." @@ -337,8 +345,10 @@ def validate_arguments( current_inputs_ptr += previous_felt_input * typ_size else: - raise AbiFormatError(ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) - previous_felt_input = inputs[current_inputs_ptr - 1] if typ == TypeFelt() else None + raise AbiFormatError( + ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) + previous_felt_input = inputs[current_inputs_ptr - + 1] if typ == TypeFelt() else None assert ( len(inputs) == current_inputs_ptr @@ -450,6 +460,8 @@ def validate_call_function_args( ) if abi_entry["name"] == args.function: + if args.raw: + break validate_arguments( inputs=inputs, abi_entry=abi_entry, @@ -467,7 +479,8 @@ def parse_call_function_args(args: argparse.Namespace) -> CallFunction: Parses the arguments and validates that the function name is in the ABI. """ inputs = cast_to_felts(values=args.inputs) - validate_call_function_args(args=args, abi_entry_type="function", inputs=inputs) + validate_call_function_args( + args=args, abi_entry_type="function", inputs=inputs) return CallFunction( contract_address=parse_hex_arg(arg=args.address, arg_name="address"), @@ -481,7 +494,8 @@ def parse_call_l1_handler_args(args: argparse.Namespace) -> CallL1Handler: Parses the arguments and validates that the l1_handler name is in the ABI. """ inputs = cast_to_felts(values=args.inputs) - from_address = parse_hex_arg(arg=args.from_address, arg_name="from_address") + from_address = parse_hex_arg( + arg=args.from_address, arg_name="from_address") validate_call_function_args( args=args, abi_entry_type="l1_handler", inputs=[from_address] + inputs ) @@ -508,7 +522,8 @@ def parse_invoke_tx_args(args: argparse.Namespace) -> InvokeFunctionArgs: def parse_declare_tx_args(args: argparse.Namespace) -> DeclareArgs: validate_max_fee(max_fee=args.max_fee) - sender = parse_hex_arg(arg=args.sender, arg_name="sender") if args.sender is not None else None + sender = parse_hex_arg( + arg=args.sender, arg_name="sender") if args.sender is not None else None return DeclareArgs( sender=sender, signature=cast_to_felts(values=args.signature), @@ -720,7 +735,8 @@ async def simulate_transaction(args: argparse.Namespace, tx: AccountTransaction) async def simulate_or_estimate_fee(args: argparse.Namespace, tx: AccountTransaction): - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) if args.simulate: await simulate_transaction(args=args, tx=tx) else: @@ -740,7 +756,8 @@ async def declare( valid signature must be provided as arguments. """ - parser = argparse.ArgumentParser(description="Sends a declare transaction to StarkNet.") + parser = argparse.ArgumentParser( + description="Sends a declare transaction to StarkNet.") add_declare_tx_arguments(parser=parser) parser.parse_args(command_args, namespace=args) declare_tx_args = parse_declare_tx_args(args=args) @@ -786,46 +803,10 @@ async def declare( async def deploy(args, command_args): - parser = argparse.ArgumentParser(description="Deploys a contract to StarkNet.") - parser.add_argument( - "--salt", - type=str, - help=( - "An optional salt controlling where the contract will be deployed. " - "The contract deployment address is determined by the hash " - "of contract, salt and caller. " - "If the salt is not supplied, the contract will be deployed with a random salt." - ), - ) - parser.add_argument( - "--inputs", type=str, nargs="*", default=[], help="The inputs to the constructor." - ) - parser.add_argument( - "--token", type=str, help="Used for deploying contracts in Alpha MainNet.", required=False - ) - parser.add_argument("--class_hash", type=str, help="The class hash of the deployed contract.") - parser.add_argument( - "--nonce", - type=int, - help=( - "Used for explicitly specifying the transaction nonce. " - "If not specified, the current nonce of the account contract " - "(as returned from StarkNet) will be used." - ), - ) - parser.add_argument( - "--max_fee", type=int, help="The maximal fee to be paid for the deployment." - ) - parser.add_argument( - "--contract", type=argparse.FileType("r"), help="The contract class to deploy." - ) - parser.add_argument( - "--deploy_from_zero", - action="store_true", - help="Use 0 instead of the deployer address for the contract address computation.", - ) + parser = argparse.ArgumentParser( + description="Deploys a contract to StarkNet.") + add_deploy_tx_arguments(parser=parser) parser.parse_args(command_args, namespace=args) - has_wallet = get_wallet_provider(args=args) is not None if has_wallet: assert args.contract is None, ( @@ -861,6 +842,8 @@ async def deploy_tx(args): if abi_entry["type"] == "constructor": try: + if args.raw: + break validate_arguments( inputs=inputs, abi_entry=abi_entry, @@ -872,7 +855,8 @@ async def deploy_tx(args): f"Failed to parse the contract ABI: {abi_error}" ) from abi_error else: - assert len(inputs) == 0, "--inputs cannot be specified for contracts without a constructor." + assert len( + inputs) == 0, "--inputs cannot be specified for contracts without a constructor." tx = Deploy( contract_address_salt=salt, @@ -941,14 +925,16 @@ async def deploy_account(args, command_args): async def call(args: argparse.Namespace, command_args: List[str]): - parser = argparse.ArgumentParser(description="Calls a function on a StarkNet contract.") + parser = argparse.ArgumentParser( + description="Calls a function on a StarkNet contract.") add_call_function_arguments(parser=parser) add_block_identifier_arguments( parser=parser, block_role_description="be used as the context for the call operation" ) parser.parse_args(command_args, namespace=args) call_function_args = parse_call_function_args(args=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_client = get_feeder_gateway_client(args) gateway_response = await feeder_client.call_contract( @@ -958,7 +944,8 @@ async def call(args: argparse.Namespace, command_args: List[str]): async def invoke(args: argparse.Namespace, command_args: List[str]): - parser = argparse.ArgumentParser(description="Sends an invoke transaction to StarkNet.") + parser = argparse.ArgumentParser( + description="Sends an invoke transaction to StarkNet.") add_invoke_tx_arguments(parser=parser) parser.parse_args(command_args, namespace=args) invoke_tx_args = parse_invoke_tx_args(args=args) @@ -1017,7 +1004,8 @@ async def invoke(args: argparse.Namespace, command_args: List[str]): def print_invoke_tx(tx: InvokeFunction, chain_id: int): sn_config_dict = StarknetGeneralConfig().dump() - sn_config_dict["starknet_os_config"]["chain_id"] = StarknetChainId(chain_id).name + sn_config_dict["starknet_os_config"]["chain_id"] = StarknetChainId( + chain_id).name sn_config = StarknetGeneralConfig.load(sn_config_dict) tx_hash = tx.calculate_hash(sn_config) out_dict = { @@ -1028,14 +1016,16 @@ def print_invoke_tx(tx: InvokeFunction, chain_id: int): async def estimate_message_fee(args: argparse.Namespace, command_args: List[str]): - parser = argparse.ArgumentParser(description="Estimates the fee of an L1-to-L2 message.") + parser = argparse.ArgumentParser( + description="Estimates the fee of an L1-to-L2 message.") add_block_identifier_arguments( parser=parser, block_role_description="be used as the context for the call operation" ) add_call_l1_handler_arguments(parser=parser) parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) call_l1_handler = parse_call_l1_handler_args(args=args) feeder_client = get_feeder_gateway_client(args=args) @@ -1090,7 +1080,8 @@ async def tx_status(args, command_args): else: addr_str, path = addr_and_path_split addr = parse_hex_arg(arg=addr_str, arg_name="address") - contracts[addr] = Program.load(data=json.load(open(path.strip()))["program"]) + contracts[addr] = Program.load( + data=json.load(open(path.strip()))["program"]) error_message = reconstruct_starknet_traceback( contracts=contracts, traceback_txt=error_message ) @@ -1117,7 +1108,8 @@ async def get_transaction(args, command_args): async def get_transaction_trace(args, command_args): - parser = argparse.ArgumentParser(description="Outputs the transaction trace given its hash.") + parser = argparse.ArgumentParser( + description="Outputs the transaction trace given its hash.") parser.add_argument( "--hash", type=str, required=True, help="The hash of the transaction to query." ) @@ -1129,7 +1121,8 @@ async def get_transaction_trace(args, command_args): async def get_transaction_receipt(args, command_args): - parser = argparse.ArgumentParser(description="Outputs the transaction receipt given its hash.") + parser = argparse.ArgumentParser( + description="Outputs the transaction receipt given its hash.") parser.add_argument( "--hash", type=str, required=True, help="The hash of the transaction to query." ) @@ -1185,8 +1178,10 @@ async def get_block_traces(args, command_args): async def get_state_update(args, command_args): - parser = argparse.ArgumentParser(description=("Outputs the state update of a given block")) - add_block_identifier_arguments(parser=parser, block_role_description="display") + parser = argparse.ArgumentParser(description=( + "Outputs the state update of a given block")) + add_block_identifier_arguments( + parser=parser, block_role_description="display") parser.parse_args(command_args, namespace=args) args.block_hash, args.block_number = parse_block_identifiers( @@ -1212,14 +1207,17 @@ async def get_code(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) code = await feeder_gateway_client.get_code( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1251,14 +1249,17 @@ async def get_full_contract(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) contract_class = await feeder_gateway_client.get_full_contract( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1275,14 +1276,17 @@ async def get_class_hash_at(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) class_hash = await feeder_gateway_client.get_class_hash_at( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1290,7 +1294,8 @@ async def get_class_hash_at(args, command_args): async def get_contract_addresses(args, command_args): - argparse.ArgumentParser(description="Outputs the addresses of the StarkNet system contracts.") + argparse.ArgumentParser( + description="Outputs the addresses of the StarkNet system contracts.") feeder_gateway_client = get_feeder_gateway_client(args) contract_addresses = await feeder_gateway_client.get_contract_addresses() @@ -1307,14 +1312,17 @@ async def get_nonce(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) nonce = await feeder_gateway_client.get_nonce( - contract_address=parse_hex_arg(args.contract_address, "contract_address"), + contract_address=parse_hex_arg( + args.contract_address, "contract_address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1334,15 +1342,18 @@ async def get_storage_at(args, command_args): parser.add_argument( "--key", type=int, help="The position in the contract's storage.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) print( await feeder_gateway_client.get_storage_at( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), key=args.key, block_hash=args.block_hash, block_number=args.block_number, @@ -1432,6 +1443,9 @@ def add_call_function_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--inputs", type=str, nargs="*", default=[], help="The inputs to the invoked function." ) + parser.add_argument( + "--raw", action="store_true", help="Function inputs are in raw format." + ) def add_call_l1_handler_arguments(parser: argparse.ArgumentParser): @@ -1486,7 +1500,8 @@ def add_block_identifier_arguments( parser: argparse.ArgumentParser, block_role_description: str, with_block_prefix: bool = True ): identifier_prefix = "block_" if with_block_prefix else "" - block_identifier_parser_group = parser.add_mutually_exclusive_group(required=False) + block_identifier_parser_group = parser.add_mutually_exclusive_group( + required=False) block_identifier_parser_group.add_argument( f"--{identifier_prefix}hash", type=str, @@ -1501,6 +1516,53 @@ def add_block_identifier_arguments( ) +def add_deploy_tx_arguments(parser: argparse.ArgumentParser): + """ + Adds the arguments: salt, inputs, raw, token, class_hash, nonce, max_fee, contract, deploy_from_zero. + """ + parser.add_argument( + "--salt", + type=str, + help=( + "An optional salt controlling where the contract will be deployed. " + "The contract deployment address is determined by the hash " + "of contract, salt and caller. " + "If the salt is not supplied, the contract will be deployed with a random salt." + ), + ) + parser.add_argument( + "--inputs", type=str, nargs="*", default=[], help="The inputs to the constructor." + ) + parser.add_argument( + "--raw", action="store_true", help="Constructor inputs are in raw format." + ) + parser.add_argument( + "--token", type=str, help="Used for deploying contracts in Alpha MainNet.", required=False + ) + parser.add_argument("--class_hash", type=str, + help="The class hash of the deployed contract.") + parser.add_argument( + "--nonce", + type=int, + help=( + "Used for explicitly specifying the transaction nonce. " + "If not specified, the current nonce of the account contract " + "(as returned from StarkNet) will be used." + ), + ) + parser.add_argument( + "--max_fee", type=int, help="The maximal fee to be paid for the deployment." + ) + parser.add_argument( + "--contract", type=argparse.FileType("r"), help="The contract class to deploy." + ) + parser.add_argument( + "--deploy_from_zero", + action="store_true", + help="Use 0 instead of the deployer address for the contract address computation.", + ) + + async def main(): subparsers = { "call": call, @@ -1524,9 +1586,12 @@ async def main(): "invoke": invoke, "tx_status": tx_status, } - parser = argparse.ArgumentParser(description="A tool to communicate with StarkNet.") - parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") - parser.add_argument("--network", type=str, help="The name of the StarkNet network.") + parser = argparse.ArgumentParser( + description="A tool to communicate with StarkNet.") + parser.add_argument("-v", "--version", action="version", + version=f"%(prog)s {__version__}") + parser.add_argument("--network", type=str, + help="The name of the StarkNet network.") parser.add_argument( "--network_id", type=str, @@ -1578,7 +1643,8 @@ async def main(): help="Print the full Python error trace in case of an internal error.", ) - parser.add_argument("--gateway_url", type=str, help="The URL of a StarkNet gateway.") + parser.add_argument("--gateway_url", type=str, + help="The URL of a StarkNet gateway.") parser.add_argument( "--feeder_gateway_url", type=str, help="The URL of a StarkNet feeder gateway." ) From 6efe206c26be7a49a9e2415984d9321966d9c7d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20B=C3=A4ttig?= Date: Wed, 5 Oct 2022 20:57:38 +0200 Subject: [PATCH 3/4] cli raw input support --- src/starkware/starknet/cli/starknet_cli.py | 240 +++++++++++++-------- 1 file changed, 153 insertions(+), 87 deletions(-) diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index fd95d747..b67ecd0e 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -136,7 +136,8 @@ def get_arg_value(args, arg_name: str, environment_var: str) -> str: Same as get_optional_arg_value, except that if the value is not defined, an exception is raised. """ - value = get_optional_arg_value(args=args, arg_name=arg_name, environment_var=environment_var) + value = get_optional_arg_value( + args=args, arg_name=arg_name, environment_var=environment_var) if value is None: raise Exception( f'{arg_name} must be specified with the "{args.command}" subcommand.\n' @@ -146,14 +147,16 @@ def get_arg_value(args, arg_name: str, environment_var: str) -> str: def get_chain_id(args) -> int: - chain_id = get_arg_value(args=args, arg_name="chain_id", environment_var="STARKNET_CHAIN_ID") + chain_id = get_arg_value( + args=args, arg_name="chain_id", environment_var="STARKNET_CHAIN_ID") if chain_id.startswith("0x"): chain_id_int = int(chain_id, 16) else: chain_id_int = from_bytes(chain_id.encode()) - assert chain_id_int in CHAIN_IDS.values(), f"Unsupported chain ID: {chain_id}." + assert chain_id_int in CHAIN_IDS.values( + ), f"Unsupported chain ID: {chain_id}." return chain_id_int @@ -169,7 +172,8 @@ def get_wallet_provider(args) -> Optional[str]: """ Returns the name of the wallet provider (of the form "module.class") as defined by the user. """ - value = get_optional_arg_value(args=args, arg_name="wallet", environment_var="STARKNET_WALLET") + value = get_optional_arg_value( + args=args, arg_name="wallet", environment_var="STARKNET_WALLET") assert value is not None, ( "A wallet must be specified (using --wallet or the STARKNET_WALLET environment variable), " "unless specifically using --no_wallet." @@ -236,7 +240,8 @@ def parse_hex_arg(arg: str, arg_name: str) -> int: Converts the given argument (hex string, starting with "0x") to an integer. """ arg = arg.strip() - assert arg.startswith("0x"), f"{arg_name} must start with '0x'. Got: '{arg}'." + assert arg.startswith( + "0x"), f"{arg_name} must start with '0x'. Got: '{arg}'." try: return int(arg, 16) except ValueError: @@ -276,7 +281,8 @@ async def compute_max_fee( tx=tx, has_block_info=False, ) - max_fee = math.ceil(simulate_tx_info.fee_estimation.overall_fee * FEE_MARGIN_OF_ESTIMATION) + max_fee = math.ceil( + simulate_tx_info.fee_estimation.overall_fee * FEE_MARGIN_OF_ESTIMATION) max_fee_eth = float(Web3.fromWei(max_fee, "ether")) print(f"Sending the transaction with max_fee: {max_fee_eth:.6f} ETH.") @@ -305,7 +311,8 @@ def validate_arguments( try: typ = mark_type_resolved(parse_type(input_desc["type"])) - typ_size = check_felts_only_type(cairo_type=typ, identifier_manager=identifier_manager) + typ_size = check_felts_only_type( + cairo_type=typ, identifier_manager=identifier_manager) except Exception as ex: raise AbiFormatError(ex) from ex @@ -324,7 +331,8 @@ def validate_arguments( raise AbiFormatError(ex) from ex if typ_size is None: - raise AbiFormatError(ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) + raise AbiFormatError( + ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) assert previous_felt_input is not None, ( f"The array argument {input_desc['name']} of type felt* must be preceded " "by a length argument of type felt." @@ -332,8 +340,10 @@ def validate_arguments( current_inputs_ptr += previous_felt_input * typ_size else: - raise AbiFormatError(ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) - previous_felt_input = inputs[current_inputs_ptr - 1] if typ == TypeFelt() else None + raise AbiFormatError( + ABI_TYPE_NOT_SUPPORTED_ERROR_FORMAT.format(typ=typ.format())) + previous_felt_input = inputs[current_inputs_ptr - + 1] if typ == TypeFelt() else None assert ( len(inputs) == current_inputs_ptr @@ -445,6 +455,8 @@ def validate_call_function_args( ) if abi_entry["name"] == args.function: + if args.raw: + break validate_arguments( inputs=inputs, abi_entry=abi_entry, @@ -462,7 +474,8 @@ def parse_call_function_args(args: argparse.Namespace) -> CallFunction: Parses the arguments and validates that the function name is in the ABI. """ inputs = cast_to_felts(values=args.inputs) - validate_call_function_args(args=args, abi_entry_type="function", inputs=inputs) + validate_call_function_args( + args=args, abi_entry_type="function", inputs=inputs) return CallFunction( contract_address=parse_hex_arg(arg=args.address, arg_name="address"), @@ -476,7 +489,8 @@ def parse_call_l1_handler_args(args: argparse.Namespace) -> CallL1Handler: Parses the arguments and validates that the l1_handler name is in the ABI. """ inputs = cast_to_felts(values=args.inputs) - from_address = parse_hex_arg(arg=args.from_address, arg_name="from_address") + from_address = parse_hex_arg( + arg=args.from_address, arg_name="from_address") validate_call_function_args( args=args, abi_entry_type="l1_handler", inputs=[from_address] + inputs ) @@ -503,7 +517,8 @@ def parse_invoke_tx_args(args: argparse.Namespace) -> InvokeFunctionArgs: def parse_declare_tx_args(args: argparse.Namespace) -> DeclareArgs: validate_max_fee(max_fee=args.max_fee) - sender = parse_hex_arg(arg=args.sender, arg_name="sender") if args.sender is not None else None + sender = parse_hex_arg( + arg=args.sender, arg_name="sender") if args.sender is not None else None return DeclareArgs( sender=sender, signature=cast_to_felts(values=args.signature), @@ -708,7 +723,8 @@ async def simulate_transaction(args: argparse.Namespace, tx: AccountTransaction) async def simulate_or_estimate_fee(args: argparse.Namespace, tx: AccountTransaction): - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) if args.simulate: await simulate_transaction(args=args, tx=tx) else: @@ -725,7 +741,8 @@ async def declare(args: argparse.Namespace, command_args: List[str]): valid signature must be provided as arguments. """ - parser = argparse.ArgumentParser(description="Sends a declare transaction to StarkNet.") + parser = argparse.ArgumentParser( + description="Sends a declare transaction to StarkNet.") add_declare_tx_arguments(parser=parser) parser.parse_args(command_args, namespace=args) declare_tx_args = parse_declare_tx_args(args=args) @@ -771,46 +788,10 @@ async def declare(args: argparse.Namespace, command_args: List[str]): async def deploy(args, command_args): - parser = argparse.ArgumentParser(description="Deploys a contract to StarkNet.") - parser.add_argument( - "--salt", - type=str, - help=( - "An optional salt controlling where the contract will be deployed. " - "The contract deployment address is determined by the hash " - "of contract, salt and caller. " - "If the salt is not supplied, the contract will be deployed with a random salt." - ), - ) - parser.add_argument( - "--inputs", type=str, nargs="*", default=[], help="The inputs to the constructor." - ) - parser.add_argument( - "--token", type=str, help="Used for deploying contracts in Alpha MainNet.", required=False - ) - parser.add_argument("--class_hash", type=str, help="The class hash of the deployed contract.") - parser.add_argument( - "--nonce", - type=int, - help=( - "Used for explicitly specifying the transaction nonce. " - "If not specified, the current nonce of the account contract " - "(as returned from StarkNet) will be used." - ), - ) - parser.add_argument( - "--max_fee", type=int, help="The maximal fee to be paid for the deployment." - ) - parser.add_argument( - "--contract", type=argparse.FileType("r"), help="The contract class to deploy." - ) - parser.add_argument( - "--deploy_from_zero", - action="store_true", - help="Use 0 instead of the deployer address for the contract address computation.", - ) + parser = argparse.ArgumentParser( + description="Deploys a contract to StarkNet.") + add_deploy_tx_arguments(parser=parser) parser.parse_args(command_args, namespace=args) - has_wallet = get_wallet_provider(args=args) is not None if has_wallet: assert args.contract is None, ( @@ -846,6 +827,8 @@ async def deploy_tx(args): if abi_entry["type"] == "constructor": try: + if args.raw: + break validate_arguments( inputs=inputs, abi_entry=abi_entry, @@ -857,7 +840,8 @@ async def deploy_tx(args): f"Failed to parse the contract ABI: {abi_error}" ) from abi_error else: - assert len(inputs) == 0, "--inputs cannot be specified for contracts without a constructor." + assert len( + inputs) == 0, "--inputs cannot be specified for contracts without a constructor." tx = Deploy( contract_address_salt=salt, @@ -984,14 +968,16 @@ async def deploy_account(args: argparse.Namespace, command_args: List[str]): async def call(args: argparse.Namespace, command_args: List[str]): - parser = argparse.ArgumentParser(description="Calls a function on a StarkNet contract.") + parser = argparse.ArgumentParser( + description="Calls a function on a StarkNet contract.") add_call_function_arguments(parser=parser) add_block_identifier_arguments( parser=parser, block_role_description="be used as the context for the call operation" ) parser.parse_args(command_args, namespace=args) call_function_args = parse_call_function_args(args=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_client = get_feeder_gateway_client(args) gateway_response = await feeder_client.call_contract( @@ -1001,7 +987,8 @@ async def call(args: argparse.Namespace, command_args: List[str]): async def invoke(args: argparse.Namespace, command_args: List[str]): - parser = argparse.ArgumentParser(description="Sends an invoke transaction to StarkNet.") + parser = argparse.ArgumentParser( + description="Sends an invoke transaction to StarkNet.") add_invoke_tx_arguments(parser=parser) parser.parse_args(command_args, namespace=args) invoke_tx_args = parse_invoke_tx_args(args=args) @@ -1060,7 +1047,8 @@ async def invoke(args: argparse.Namespace, command_args: List[str]): def print_invoke_tx(tx: InvokeFunction, chain_id: int): sn_config_dict = StarknetGeneralConfig().dump() - sn_config_dict["starknet_os_config"]["chain_id"] = StarknetChainId(chain_id).name + sn_config_dict["starknet_os_config"]["chain_id"] = StarknetChainId( + chain_id).name sn_config = StarknetGeneralConfig.load(sn_config_dict) tx_hash = tx.calculate_hash(sn_config) out_dict = { @@ -1071,14 +1059,16 @@ def print_invoke_tx(tx: InvokeFunction, chain_id: int): async def estimate_message_fee(args: argparse.Namespace, command_args: List[str]): - parser = argparse.ArgumentParser(description="Estimates the fee of an L1-to-L2 message.") + parser = argparse.ArgumentParser( + description="Estimates the fee of an L1-to-L2 message.") add_block_identifier_arguments( parser=parser, block_role_description="be used as the context for the call operation" ) add_call_l1_handler_arguments(parser=parser) parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) call_l1_handler = parse_call_l1_handler_args(args=args) feeder_client = get_feeder_gateway_client(args=args) @@ -1133,7 +1123,8 @@ async def tx_status(args, command_args): else: addr_str, path = addr_and_path_split addr = parse_hex_arg(arg=addr_str, arg_name="address") - contracts[addr] = Program.load(data=json.load(open(path.strip()))["program"]) + contracts[addr] = Program.load( + data=json.load(open(path.strip()))["program"]) error_message = reconstruct_starknet_traceback( contracts=contracts, traceback_txt=error_message ) @@ -1160,7 +1151,8 @@ async def get_transaction(args, command_args): async def get_transaction_trace(args, command_args): - parser = argparse.ArgumentParser(description="Outputs the transaction trace given its hash.") + parser = argparse.ArgumentParser( + description="Outputs the transaction trace given its hash.") parser.add_argument( "--hash", type=str, required=True, help="The hash of the transaction to query." ) @@ -1172,7 +1164,8 @@ async def get_transaction_trace(args, command_args): async def get_transaction_receipt(args, command_args): - parser = argparse.ArgumentParser(description="Outputs the transaction receipt given its hash.") + parser = argparse.ArgumentParser( + description="Outputs the transaction receipt given its hash.") parser.add_argument( "--hash", type=str, required=True, help="The hash of the transaction to query." ) @@ -1228,8 +1221,10 @@ async def get_block_traces(args, command_args): async def get_state_update(args, command_args): - parser = argparse.ArgumentParser(description=("Outputs the state update of a given block")) - add_block_identifier_arguments(parser=parser, block_role_description="display") + parser = argparse.ArgumentParser(description=( + "Outputs the state update of a given block")) + add_block_identifier_arguments( + parser=parser, block_role_description="display") parser.parse_args(command_args, namespace=args) args.block_hash, args.block_number = parse_block_identifiers( @@ -1255,14 +1250,17 @@ async def get_code(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) code = await feeder_gateway_client.get_code( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1294,14 +1292,17 @@ async def get_full_contract(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) contract_class = await feeder_gateway_client.get_full_contract( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1318,14 +1319,17 @@ async def get_class_hash_at(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) class_hash = await feeder_gateway_client.get_class_hash_at( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1333,7 +1337,8 @@ async def get_class_hash_at(args, command_args): async def get_contract_addresses(args, command_args): - argparse.ArgumentParser(description="Outputs the addresses of the StarkNet system contracts.") + argparse.ArgumentParser( + description="Outputs the addresses of the StarkNet system contracts.") feeder_gateway_client = get_feeder_gateway_client(args) contract_addresses = await feeder_gateway_client.get_contract_addresses() @@ -1350,14 +1355,17 @@ async def get_nonce(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) nonce = await feeder_gateway_client.get_nonce( - contract_address=parse_hex_arg(args.contract_address, "contract_address"), + contract_address=parse_hex_arg( + args.contract_address, "contract_address"), block_hash=args.block_hash, block_number=args.block_number, ) @@ -1377,15 +1385,18 @@ async def get_storage_at(args, command_args): parser.add_argument( "--key", type=int, help="The position in the contract's storage.", required=True ) - add_block_identifier_arguments(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments( + parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) - args.block_hash, args.block_number = parse_block_identifiers(args.block_hash, args.block_number) + args.block_hash, args.block_number = parse_block_identifiers( + args.block_hash, args.block_number) feeder_gateway_client = get_feeder_gateway_client(args) print( await feeder_gateway_client.get_storage_at( - contract_address=parse_hex_arg(arg=args.contract_address, arg_name="contract address"), + contract_address=parse_hex_arg( + arg=args.contract_address, arg_name="contract address"), key=args.key, block_hash=args.block_hash, block_number=args.block_number, @@ -1482,6 +1493,9 @@ def add_call_function_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--inputs", type=str, nargs="*", default=[], help="The inputs to the invoked function." ) + parser.add_argument( + "--raw", action="store_true", help="Function inputs are in raw format." + ) def add_call_l1_handler_arguments(parser: argparse.ArgumentParser): @@ -1524,7 +1538,8 @@ def add_block_identifier_arguments( parser: argparse.ArgumentParser, block_role_description: str, with_block_prefix: bool = True ): identifier_prefix = "block_" if with_block_prefix else "" - block_identifier_parser_group = parser.add_mutually_exclusive_group(required=False) + block_identifier_parser_group = parser.add_mutually_exclusive_group( + required=False) block_identifier_parser_group.add_argument( f"--{identifier_prefix}hash", type=str, @@ -1539,6 +1554,53 @@ def add_block_identifier_arguments( ) +def add_deploy_tx_arguments(parser: argparse.ArgumentParser): + """ + Adds the arguments: salt, inputs, raw, token, class_hash, nonce, max_fee, contract, deploy_from_zero. + """ + parser.add_argument( + "--salt", + type=str, + help=( + "An optional salt controlling where the contract will be deployed. " + "The contract deployment address is determined by the hash " + "of contract, salt and caller. " + "If the salt is not supplied, the contract will be deployed with a random salt." + ), + ) + parser.add_argument( + "--inputs", type=str, nargs="*", default=[], help="The inputs to the constructor." + ) + parser.add_argument( + "--raw", action="store_true", help="Constructor inputs are in raw format." + ) + parser.add_argument( + "--token", type=str, help="Used for deploying contracts in Alpha MainNet.", required=False + ) + parser.add_argument("--class_hash", type=str, + help="The class hash of the deployed contract.") + parser.add_argument( + "--nonce", + type=int, + help=( + "Used for explicitly specifying the transaction nonce. " + "If not specified, the current nonce of the account contract " + "(as returned from StarkNet) will be used." + ), + ) + parser.add_argument( + "--max_fee", type=int, help="The maximal fee to be paid for the deployment." + ) + parser.add_argument( + "--contract", type=argparse.FileType("r"), help="The contract class to deploy." + ) + parser.add_argument( + "--deploy_from_zero", + action="store_true", + help="Use 0 instead of the deployer address for the contract address computation.", + ) + + async def main(): subparsers = { "call": call, @@ -1563,9 +1625,12 @@ async def main(): "new_account": new_account, "tx_status": tx_status, } - parser = argparse.ArgumentParser(description="A tool to communicate with StarkNet.") - parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") - parser.add_argument("--network", type=str, help="The name of the StarkNet network.") + parser = argparse.ArgumentParser( + description="A tool to communicate with StarkNet.") + parser.add_argument("-v", "--version", action="version", + version=f"%(prog)s {__version__}") + parser.add_argument("--network", type=str, + help="The name of the StarkNet network.") parser.add_argument( "--network_id", type=str, @@ -1617,7 +1682,8 @@ async def main(): help="Print the full Python error trace in case of an internal error.", ) - parser.add_argument("--gateway_url", type=str, help="The URL of a StarkNet gateway.") + parser.add_argument("--gateway_url", type=str, + help="The URL of a StarkNet gateway.") parser.add_argument( "--feeder_gateway_url", type=str, help="The URL of a StarkNet feeder gateway." ) From 75a721e9ffa5027ed295d92cc68115ae10194226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20B=C3=A4ttig?= Date: Mon, 12 Dec 2022 14:04:54 +0100 Subject: [PATCH 4/4] added raw flag --- src/starkware/starknet/cli/starknet_cli.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index ddca7c6a..a512eb1a 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -1044,6 +1044,9 @@ def validate_call_function_args( ) if abi_entry["name"] == args.function: + if args.raw: + break + validate_arguments( inputs=inputs, abi_entry=abi_entry, @@ -1392,7 +1395,7 @@ def add_declare_tx_arguments(parser: argparse.ArgumentParser): def add_call_function_arguments(parser: argparse.ArgumentParser): """ - Adds the arguments: address, abi, function, inputs. + Adds the arguments: address, abi, function, inputs, raw. """ parser.add_argument( "--address", type=str, required=True, help="The address of the invoked contract." @@ -1406,6 +1409,9 @@ def add_call_function_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--inputs", type=str, nargs="*", default=[], help="The inputs to the invoked function." ) + parser.add_argument( + "--raw", action="store_true", help="Function inputs are in raw format." + ) def add_call_l1_handler_arguments(parser: argparse.ArgumentParser):