From 544523567d16d22c5c423fcccabedd4715e391a6 Mon Sep 17 00:00:00 2001 From: Lior Goldberg Date: Sun, 13 Jun 2021 15:57:01 +0300 Subject: [PATCH] Cairo v0.2.0. --- .gitignore | 1 + Dockerfile | 2 +- README.md | 4 +- scripts/requirements-deps.json | 234 ++- scripts/requirements-gen.txt | 3 + scripts/requirements.txt | 48 +- src/CMakeLists.txt | 1 + src/cmake_utils/gen_py_lib.py | 15 +- src/cmake_utils/gen_python_exe.py | 1 + src/cmake_utils/pip_rules.cmake | 8 +- src/cmake_utils/python_rules.cmake | 3 + src/services/CMakeLists.txt | 2 + src/services/everest/CMakeLists.txt | 2 + src/services/everest/api/CMakeLists.txt | 2 + .../everest/api/feeder_gateway/CMakeLists.txt | 9 + .../feeder_gateway/feeder_gateway_client.py | 19 + .../everest/api/gateway/CMakeLists.txt | 28 + .../everest/api/gateway/gateway_client.py | 26 + .../everest/api/gateway/transaction.py | 19 + .../everest/api/gateway/transaction_type.py | 9 + .../everest/definitions/CMakeLists.txt | 10 + src/services/everest/definitions/fields.py | 9 + src/services/external_api/CMakeLists.txt | 12 + src/services/external_api/base_client.py | 120 ++ src/services/external_api/has_uri_prefix.py | 24 + src/starkware/CMakeLists.txt | 4 +- .../cairo/apps/starkex2_0/CMakeLists.txt | 1 + .../cairo/apps/starkex2_0/__start__.cairo | 10 + .../starkex2_0/common/cairo_builtins.cairo | 7 - .../cairo/apps/starkex2_0/common/dict.cairo | 2 +- .../cairo/apps/starkex2_0/execute_batch.cairo | 3 +- .../apps/starkex2_0/execute_transfer.cairo | 3 +- .../apps/starkex2_0/hash_vault_ptr_dict.cairo | 6 +- .../cairo/apps/starkex2_0/main.cairo | 7 +- .../starkex2_0/starkex2_0_program_test.py | 1 + src/starkware/cairo/bootloader/CMakeLists.txt | 2 +- .../cairo/bootloader/fact_topology.py | 6 +- src/starkware/cairo/common/CMakeLists.txt | 21 + .../cairo/common/cairo_builtins.cairo | 7 - .../cairo/common/cairo_function_runner.py | 150 ++ src/starkware/cairo/common/find_element.cairo | 49 +- src/starkware/cairo/common/hash_state.cairo | 43 +- 
src/starkware/cairo/common/invoke.cairo | 21 + src/starkware/cairo/common/math.cairo | 62 +- src/starkware/cairo/common/math_cmp.cairo | 76 + src/starkware/cairo/common/math_utils.py | 8 + src/starkware/cairo/common/memcpy.cairo | 6 +- src/starkware/cairo/common/registers.cairo | 12 +- src/starkware/cairo/common/set.cairo | 58 + .../cairo/common/small_merkle_tree_test.py | 50 - src/starkware/cairo/common/squash_dict.cairo | 15 +- src/starkware/cairo/common/structs.py | 115 ++ src/starkware/cairo/common/uint256.cairo | 83 + src/starkware/cairo/lang/CMakeLists.txt | 55 +- src/starkware/cairo/lang/VERSION | 2 +- .../cairo/lang/builtins/CMakeLists.txt | 4 +- .../checkpoints/checkpoints_builtin_runner.py | 37 - .../lang/builtins/checkpoints/instance_def.py | 11 - .../cairo/lang/compiler/CMakeLists.txt | 14 + .../cairo/lang/compiler/assembler_test.py | 6 +- .../compiler/ast/ast_objects_test_utils.py | 24 + .../cairo/lang/compiler/ast/code_elements.py | 23 +- src/starkware/cairo/lang/compiler/ast/expr.py | 57 +- .../cairo/lang/compiler/ast/expr_func_call.py | 25 + .../cairo/lang/compiler/ast/rvalue.py | 14 + .../cairo/lang/compiler/ast/visitor.py | 46 +- .../cairo/lang/compiler/ast_objects_test.py | 62 +- src/starkware/cairo/lang/compiler/cairo.ebnf | 16 +- .../cairo/lang/compiler/cairo_compile.py | 197 ++- src/starkware/cairo/lang/compiler/conftest.py | 6 + .../cairo/lang/compiler/const_expr_checker.py | 13 +- .../lang/compiler/expression_evaluator.py | 10 +- .../lang/compiler/expression_transformer.py | 51 +- src/starkware/cairo/lang/compiler/fields.py | 2 +- .../lang/compiler/identifier_definition.py | 9 + .../cairo/lang/compiler/identifier_manager.py | 32 +- .../lang/compiler/identifier_manager_test.py | 2 +- .../cairo/lang/compiler/identifier_utils.py | 47 +- .../lang/compiler/identifier_utils_test.py | 31 +- .../cairo/lang/compiler/import_loader.py | 11 +- .../cairo/lang/compiler/import_loader_test.py | 61 +- .../cairo/lang/compiler/instruction_test.py | 6 +- 
.../cairo/lang/compiler/offset_reference.py | 47 +- .../lang/compiler/offset_reference_test.py | 57 +- src/starkware/cairo/lang/compiler/parser.py | 2 + .../cairo/lang/compiler/parser_errors_test.py | 36 +- .../cairo/lang/compiler/parser_test.py | 106 +- .../cairo/lang/compiler/parser_test_utils.py | 15 + .../cairo/lang/compiler/parser_transformer.py | 42 +- .../preprocessor/compound_expressions_test.py | 19 +- .../lang/compiler/preprocessor/conftest.py | 3 + .../preprocessor/default_pass_manager.py | 96 ++ .../compiler/preprocessor/dependency_graph.py | 138 +- .../preprocessor/dependency_graph_test.py | 118 +- .../preprocessor/identifier_aware_visitor.py | 32 +- .../identifier_aware_visitor_test.py | 23 + .../preprocessor/identifier_collector.py | 24 +- .../preprocessor/identifier_collector_test.py | 9 +- .../compiler/preprocessor/pass_manager.py | 104 ++ .../compiler/preprocessor/preprocess_codes.py | 25 + .../compiler/preprocessor/preprocessor.py | 415 +++-- .../preprocessor/preprocessor_test.py | 688 +++++++- .../preprocessor/preprocessor_test_utils.py | 40 +- .../preprocessor/unique_labels_test.py | 8 + src/starkware/cairo/lang/compiler/program.py | 31 +- .../cairo/lang/compiler/references.py | 14 +- .../lang/compiler/resolve_search_result.py | 45 + .../compiler/resolve_search_result_test.py | 36 + .../lang/compiler/substitute_identifiers.py | 37 +- .../cairo/lang/compiler/type_casts.py | 46 +- .../cairo/lang/compiler/type_casts_test.py | 8 +- .../cairo/lang/compiler/type_system.py | 58 + .../lang/compiler/type_system_visitor.py | 245 ++- .../lang/compiler/type_system_visitor_test.py | 320 +++- .../cairo/lang/ide/vscode-cairo/package.json | 2 +- src/starkware/cairo/lang/instances.py | 16 +- src/starkware/cairo/lang/lang.cmake | 55 + .../cairo/lang/package_test/run_test.sh | 7 + src/starkware/cairo/lang/setup.py | 14 +- src/starkware/cairo/lang/tracer/profile.py | 3 + .../cairo/lang/tracer/tracer_data.py | 9 +- .../cairo/lang/tracer/tracer_data_test.py | 11 +- 
src/starkware/cairo/lang/vm/CMakeLists.txt | 5 +- src/starkware/cairo/lang/vm/cairo_pie.py | 31 +- src/starkware/cairo/lang/vm/cairo_runner.py | 96 +- .../cairo/lang/vm/cairo_runner_test.py | 4 +- src/starkware/cairo/lang/vm/crypto.py | 1 + src/starkware/cairo/lang/vm/memory_dict.py | 81 +- .../cairo/lang/vm/memory_dict_test.py | 108 +- .../cairo/lang/vm/memory_segments.py | 34 +- .../lang/vm/output_builtin_runner_test.py | 2 +- src/starkware/cairo/lang/vm/vm.py | 35 +- src/starkware/cairo/lang/vm/vm_consts.py | 46 +- src/starkware/cairo/lang/vm/vm_consts_test.py | 30 +- src/starkware/cairo/lang/vm/vm_test.py | 18 + src/starkware/cairo/sharp/config.json | 2 +- .../starkware/crypto/signature/signature.py | 6 +- src/starkware/python/CMakeLists.txt | 3 + src/starkware/python/async_subprocess.py | 18 + src/starkware/python/random_test.py | 135 ++ src/starkware/python/utils.py | 13 +- src/starkware/python/utils_test.py | 18 +- src/starkware/starknet/CMakeLists.txt | 8 + .../starknet/apps/amm_sample/amm_sample.cairo | 141 ++ src/starkware/starknet/cli/CMakeLists.txt | 30 + src/starkware/starknet/cli/starknet_cli.py | 259 +++ .../starknet/compiler/CMakeLists.txt | 53 + src/starkware/starknet/compiler/__init__.py | 0 .../starknet/compiler/calldata_parser.py | 42 + .../starknet/compiler/calldata_parser_test.py | 59 + src/starkware/starknet/compiler/compile.py | 106 ++ src/starkware/starknet/compiler/conftest.py | 3 + .../compiler/starknet_pass_manager.py | 38 + .../compiler/starknet_preprocessor.py | 394 +++++ .../compiler/starknet_preprocessor_test.py | 244 +++ .../starknet/compiler/storage_var.py | 218 +++ .../starknet/compiler/storage_var_test.py | 194 +++ src/starkware/starknet/compiler/test_utils.py | 51 + src/starkware/starknet/core/CMakeLists.txt | 1 + .../starknet/core/storage/CMakeLists.txt | 20 + .../starknet/core/storage/__init__.py | 0 .../starknet/core/storage/storage.cairo | 36 + .../starknet/core/storage/storage_test.py | 72 + 
.../starknet/definitions/CMakeLists.txt | 15 + .../starknet/definitions/__init__.py | 0 .../starknet/definitions/constants.py | 16 + .../starknet/definitions/error_codes.py | 23 + src/starkware/starknet/definitions/fields.py | 72 + .../starknet/definitions/transaction_type.py | 6 + src/starkware/starknet/public/CMakeLists.txt | 24 + src/starkware/starknet/public/__init__.py | 0 src/starkware/starknet/public/abi.py | 26 + src/starkware/starknet/public/abi_test.py | 7 + src/starkware/starknet/scripts/CMakeLists.txt | 10 + src/starkware/starknet/scripts/starknet | 12 + .../starknet/scripts/starknet-compile | 10 + .../starknet/security/CMakeLists.txt | 46 + src/starkware/starknet/security/__init__.py | 0 .../starknet/security/hints_whitelist.py | 8 + .../security/latest_whitelist_test.py | 41 + .../starknet/security/secure_hints.py | 128 ++ .../starknet/security/secure_hints_test.py | 112 ++ .../starknet/security/starknet_common.cairo | 14 + .../starknet/security/whitelists/latest.json | 1394 +++++++++++++++++ .../starknet/services/CMakeLists.txt | 1 + .../starknet/services/api/CMakeLists.txt | 16 + .../starknet/services/api/__init__.py | 0 .../services/api/contract_definition.py | 40 + .../api/feeder_gateway/CMakeLists.txt | 10 + .../feeder_gateway/feeder_gateway_client.py | 42 + .../services/api/gateway/CMakeLists.txt | 27 + .../services/api/gateway/gateway_client.py | 15 + .../services/api/gateway/transaction.py | 116 ++ src/starkware/starkware_utils/CMakeLists.txt | 29 + .../starkware_utils/custom_raising_dict.py | 71 + .../starkware_utils/error_handling.py | 209 +++ .../starkware_utils/field_validators.py | 257 +++ .../marshmallow_dataclass_fields.py | 164 ++ src/starkware/starkware_utils/serializable.py | 116 ++ .../starkware_utils/validated_dataclass.py | 331 ++++ .../starkware_utils/validated_fields.py | 206 +++ 201 files changed, 10295 insertions(+), 1291 deletions(-) create mode 100644 src/services/CMakeLists.txt create mode 100644 
src/services/everest/CMakeLists.txt create mode 100644 src/services/everest/api/CMakeLists.txt create mode 100644 src/services/everest/api/feeder_gateway/CMakeLists.txt create mode 100644 src/services/everest/api/feeder_gateway/feeder_gateway_client.py create mode 100644 src/services/everest/api/gateway/CMakeLists.txt create mode 100644 src/services/everest/api/gateway/gateway_client.py create mode 100644 src/services/everest/api/gateway/transaction.py create mode 100644 src/services/everest/api/gateway/transaction_type.py create mode 100644 src/services/everest/definitions/CMakeLists.txt create mode 100644 src/services/everest/definitions/fields.py create mode 100644 src/services/external_api/CMakeLists.txt create mode 100644 src/services/external_api/base_client.py create mode 100644 src/services/external_api/has_uri_prefix.py create mode 100644 src/starkware/cairo/apps/starkex2_0/__start__.cairo create mode 100644 src/starkware/cairo/common/cairo_function_runner.py create mode 100644 src/starkware/cairo/common/invoke.cairo create mode 100644 src/starkware/cairo/common/math_cmp.cairo create mode 100644 src/starkware/cairo/common/set.cairo delete mode 100644 src/starkware/cairo/common/small_merkle_tree_test.py create mode 100644 src/starkware/cairo/common/structs.py create mode 100644 src/starkware/cairo/common/uint256.cairo delete mode 100644 src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py delete mode 100644 src/starkware/cairo/lang/builtins/checkpoints/instance_def.py create mode 100644 src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py create mode 100644 src/starkware/cairo/lang/compiler/ast/expr_func_call.py create mode 100644 src/starkware/cairo/lang/compiler/conftest.py create mode 100644 src/starkware/cairo/lang/compiler/parser_test_utils.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py create mode 100644 
src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py create mode 100644 src/starkware/cairo/lang/compiler/resolve_search_result.py create mode 100644 src/starkware/cairo/lang/compiler/resolve_search_result_test.py create mode 100644 src/starkware/cairo/lang/compiler/type_system.py create mode 100644 src/starkware/cairo/lang/lang.cmake create mode 100644 src/starkware/python/async_subprocess.py create mode 100644 src/starkware/python/random_test.py create mode 100644 src/starkware/starknet/CMakeLists.txt create mode 100644 src/starkware/starknet/apps/amm_sample/amm_sample.cairo create mode 100644 src/starkware/starknet/cli/CMakeLists.txt create mode 100755 src/starkware/starknet/cli/starknet_cli.py create mode 100644 src/starkware/starknet/compiler/CMakeLists.txt create mode 100644 src/starkware/starknet/compiler/__init__.py create mode 100644 src/starkware/starknet/compiler/calldata_parser.py create mode 100644 src/starkware/starknet/compiler/calldata_parser_test.py create mode 100644 src/starkware/starknet/compiler/compile.py create mode 100644 src/starkware/starknet/compiler/conftest.py create mode 100644 src/starkware/starknet/compiler/starknet_pass_manager.py create mode 100644 src/starkware/starknet/compiler/starknet_preprocessor.py create mode 100644 src/starkware/starknet/compiler/starknet_preprocessor_test.py create mode 100644 src/starkware/starknet/compiler/storage_var.py create mode 100644 src/starkware/starknet/compiler/storage_var_test.py create mode 100644 src/starkware/starknet/compiler/test_utils.py create mode 100644 src/starkware/starknet/core/CMakeLists.txt create mode 100644 src/starkware/starknet/core/storage/CMakeLists.txt create mode 100644 src/starkware/starknet/core/storage/__init__.py create mode 100644 
src/starkware/starknet/core/storage/storage.cairo create mode 100644 src/starkware/starknet/core/storage/storage_test.py create mode 100644 src/starkware/starknet/definitions/CMakeLists.txt create mode 100644 src/starkware/starknet/definitions/__init__.py create mode 100644 src/starkware/starknet/definitions/constants.py create mode 100644 src/starkware/starknet/definitions/error_codes.py create mode 100644 src/starkware/starknet/definitions/fields.py create mode 100644 src/starkware/starknet/definitions/transaction_type.py create mode 100644 src/starkware/starknet/public/CMakeLists.txt create mode 100644 src/starkware/starknet/public/__init__.py create mode 100644 src/starkware/starknet/public/abi.py create mode 100644 src/starkware/starknet/public/abi_test.py create mode 100644 src/starkware/starknet/scripts/CMakeLists.txt create mode 100755 src/starkware/starknet/scripts/starknet create mode 100755 src/starkware/starknet/scripts/starknet-compile create mode 100644 src/starkware/starknet/security/CMakeLists.txt create mode 100644 src/starkware/starknet/security/__init__.py create mode 100644 src/starkware/starknet/security/hints_whitelist.py create mode 100644 src/starkware/starknet/security/latest_whitelist_test.py create mode 100644 src/starkware/starknet/security/secure_hints.py create mode 100644 src/starkware/starknet/security/secure_hints_test.py create mode 100644 src/starkware/starknet/security/starknet_common.cairo create mode 100644 src/starkware/starknet/security/whitelists/latest.json create mode 100644 src/starkware/starknet/services/CMakeLists.txt create mode 100644 src/starkware/starknet/services/api/CMakeLists.txt create mode 100644 src/starkware/starknet/services/api/__init__.py create mode 100644 src/starkware/starknet/services/api/contract_definition.py create mode 100644 src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt create mode 100644 src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py create mode 
100644 src/starkware/starknet/services/api/gateway/CMakeLists.txt create mode 100644 src/starkware/starknet/services/api/gateway/gateway_client.py create mode 100644 src/starkware/starknet/services/api/gateway/transaction.py create mode 100644 src/starkware/starkware_utils/CMakeLists.txt create mode 100644 src/starkware/starkware_utils/custom_raising_dict.py create mode 100644 src/starkware/starkware_utils/error_handling.py create mode 100644 src/starkware/starkware_utils/field_validators.py create mode 100644 src/starkware/starkware_utils/marshmallow_dataclass_fields.py create mode 100644 src/starkware/starkware_utils/serializable.py create mode 100644 src/starkware/starkware_utils/validated_dataclass.py create mode 100644 src/starkware/starkware_utils/validated_fields.py diff --git a/.gitignore b/.gitignore index f719f659..40e870ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /build/ __pycache__/ cairo-lang-*.zip +/.vscode/ diff --git a/Dockerfile b/Dockerfile index f22d6513..36b3a775 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,7 @@ RUN src/starkware/cairo/lang/package_test/run_test.sh # Build the Visual Studio Code extension. WORKDIR /app/src/starkware/cairo/lang/ide/vscode-cairo -RUN npm install -g vsce +RUN npm install -g vsce@1.87.1 RUN npm install RUN vsce package diff --git a/README.md b/README.md index d2cf5518..fbf911a4 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ We recommend starting from [Setting up the environment](https://cairo-lang.org/d # Installation instructions You should be able to download the python package zip file directly from -[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.1.0) +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.2.0) and install it using ``pip``. See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). 
@@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.1.0.zip . +> docker cp ${container_id}:/app/cairo-lang-0.2.0.zip . > docker rm -v ${container_id} ``` diff --git a/scripts/requirements-deps.json b/scripts/requirements-deps.json index 0eb8c5e2..0506b04d 100644 --- a/scripts/requirements-deps.json +++ b/scripts/requirements-deps.json @@ -1,8 +1,61 @@ [ + { + "dependencies": [ + { + "installed_version": "3.0.1", + "key": "async-timeout", + "package_name": "async-timeout", + "required_version": ">=3.0,<4.0" + }, + { + "installed_version": "21.2.0", + "key": "attrs", + "package_name": "attrs", + "required_version": ">=17.3.0" + }, + { + "installed_version": "4.0.0", + "key": "chardet", + "package_name": "chardet", + "required_version": ">=2.0,<5.0" + }, + { + "installed_version": "5.1.0", + "key": "multidict", + "package_name": "multidict", + "required_version": ">=4.5,<7.0" + }, + { + "installed_version": "3.10.0.0", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": ">=3.6.5" + }, + { + "installed_version": "1.6.3", + "key": "yarl", + "package_name": "yarl", + "required_version": ">=1.0,<2.0" + } + ], + "package": { + "installed_version": "3.7.4.post0", + "key": "aiohttp", + "package_name": "aiohttp" + } + }, { "dependencies": [], "package": { - "installed_version": "20.2.0", + "installed_version": "3.0.1", + "key": "async-timeout", + "package_name": "async-timeout" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "21.2.0", "key": "attrs", "package_name": "attrs" } @@ -57,14 +110,14 @@ { "dependencies": [ { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": ">=1.9.0" } ], "package": { - "installed_version": "0.16.0", + "installed_version": "0.17.0", "key": "ecdsa", "package_name": "ecdsa" } @@ 
-78,7 +131,7 @@ "required_version": ">=2.0.0,<3.0.0" }, { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.2.0,<2.0.0" @@ -129,7 +182,7 @@ "required_version": ">=0.1.2,<2" }, { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.3.0,<2" @@ -176,13 +229,13 @@ "required_version": ">=0.1.0-beta.4,<1.0.0" }, { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.0.0-beta.1,<2.0.0" }, { - "installed_version": "3.9.9", + "installed_version": "3.10.1", "key": "pycryptodome", "package_name": "pycryptodome", "required_version": ">=3.4.7,<4.0.0" @@ -203,7 +256,7 @@ "required_version": ">=2.2.1,<3.0.0" }, { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.3.0,<2.0.0" @@ -218,7 +271,7 @@ { "dependencies": [ { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.0.1,<2" @@ -262,7 +315,7 @@ "installed_version": "0.2.0", "key": "eth-hash", "package_name": "eth-hash", - "required_version": ">=0.3.1,<0.4.0" + "required_version": ">=0.1.0,<1.0.0" }, { "installed_version": "2.2.2", @@ -272,7 +325,7 @@ } ], "package": { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils" } @@ -285,6 +338,14 @@ "package_name": "fastecdsa" } }, + { + "dependencies": [], + "package": { + "installed_version": "1.2", + "key": "frozendict", + "package_name": "frozendict" + } + }, { "dependencies": [], "package": { @@ -304,14 +365,20 @@ { "dependencies": [ { - "installed_version": "3.4.0", + "installed_version": "3.10.0.0", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": 
">=3.6.4" + }, + { + "installed_version": "3.4.1", "key": "zipp", "package_name": "zipp", "required_version": ">=0.5" } ], "package": { - "installed_version": "2.0.0", + "installed_version": "4.3.1", "key": "importlib-metadata", "package_name": "importlib-metadata" } @@ -348,13 +415,13 @@ { "dependencies": [ { - "installed_version": "20.2.0", + "installed_version": "21.2.0", "key": "attrs", "package_name": "attrs", "required_version": ">=17.4.0" }, { - "installed_version": "2.0.0", + "installed_version": "4.3.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": null @@ -366,13 +433,13 @@ "required_version": ">=0.14.0" }, { - "installed_version": "47.1.1", + "installed_version": "57.0.0", "key": "setuptools", "package_name": "setuptools", "required_version": null }, { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": ">=1.11.0" @@ -403,7 +470,7 @@ { "dependencies": [], "package": { - "installed_version": "3.8.0", + "installed_version": "3.12.1", "key": "marshmallow", "package_name": "marshmallow" } @@ -411,7 +478,7 @@ { "dependencies": [ { - "installed_version": "3.8.0", + "installed_version": "3.12.1", "key": "marshmallow", "package_name": "marshmallow", "required_version": ">=3.0.0,<4.0" @@ -424,7 +491,7 @@ } ], "package": { - "installed_version": "8.1.0", + "installed_version": "8.4.1", "key": "marshmallow-dataclass", "package_name": "marshmallow-dataclass" } @@ -432,7 +499,7 @@ { "dependencies": [ { - "installed_version": "3.8.0", + "installed_version": "3.12.1", "key": "marshmallow", "package_name": "marshmallow", "required_version": ">=2.0.0" @@ -447,7 +514,7 @@ { "dependencies": [ { - "installed_version": "3.8.0", + "installed_version": "3.12.1", "key": "marshmallow", "package_name": "marshmallow", "required_version": ">=3.0.0rc6,<4.0.0" @@ -462,7 +529,7 @@ { "dependencies": [], "package": { - "installed_version": "1.1.0", + "installed_version": 
"1.2.1", "key": "mpmath", "package_name": "mpmath" } @@ -482,7 +549,7 @@ "required_version": null }, { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": null @@ -500,6 +567,14 @@ "package_name": "multiaddr" } }, + { + "dependencies": [], + "package": { + "installed_version": "5.1.0", + "key": "multidict", + "package_name": "multidict" + } + }, { "dependencies": [], "package": { @@ -519,7 +594,7 @@ { "dependencies": [], "package": { - "installed_version": "1.19.2", + "installed_version": "1.20.3", "key": "numpy", "package_name": "numpy" } @@ -531,16 +606,10 @@ "key": "pyparsing", "package_name": "pyparsing", "required_version": ">=2.0.2" - }, - { - "installed_version": "1.15.0", - "key": "six", - "package_name": "six", - "required_version": null } ], "package": { - "installed_version": "20.4", + "installed_version": "20.9", "key": "packaging", "package_name": "packaging" } @@ -548,7 +617,7 @@ { "dependencies": [ { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": ">=1.9.0" @@ -563,7 +632,7 @@ { "dependencies": [], "package": { - "installed_version": "20.1.1", + "installed_version": "21.1.2", "key": "pip", "package_name": "pip" } @@ -571,14 +640,14 @@ { "dependencies": [ { - "installed_version": "20.1.1", + "installed_version": "21.1.2", "key": "pip", "package_name": "pip", "required_version": ">=6.0.0" } ], "package": { - "installed_version": "1.0.0", + "installed_version": "2.0.0", "key": "pipdeptree", "package_name": "pipdeptree" } @@ -586,7 +655,7 @@ { "dependencies": [ { - "installed_version": "2.0.0", + "installed_version": "4.3.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": ">=0.12" @@ -601,14 +670,14 @@ { "dependencies": [ { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": ">=1.9" } ], "package": { - 
"installed_version": "3.15.1", + "installed_version": "3.17.1", "key": "protobuf", "package_name": "protobuf" } @@ -616,7 +685,7 @@ { "dependencies": [], "package": { - "installed_version": "1.9.0", + "installed_version": "1.10.0", "key": "py", "package_name": "py" } @@ -624,7 +693,7 @@ { "dependencies": [], "package": { - "installed_version": "3.9.9", + "installed_version": "3.10.1", "key": "pycryptodome", "package_name": "pycryptodome" } @@ -648,13 +717,13 @@ { "dependencies": [ { - "installed_version": "20.2.0", + "installed_version": "21.2.0", "key": "attrs", "package_name": "attrs", - "required_version": ">=17.4.0" + "required_version": ">=19.2.0" }, { - "installed_version": "2.0.0", + "installed_version": "4.3.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": ">=0.12" @@ -666,7 +735,7 @@ "required_version": null }, { - "installed_version": "20.4", + "installed_version": "20.9", "key": "packaging", "package_name": "packaging", "required_version": null @@ -675,23 +744,23 @@ "installed_version": "0.13.1", "key": "pluggy", "package_name": "pluggy", - "required_version": ">=0.12,<1.0" + "required_version": ">=0.12,<1.0.0a1" }, { - "installed_version": "1.9.0", + "installed_version": "1.10.0", "key": "py", "package_name": "py", "required_version": ">=1.8.2" }, { - "installed_version": "0.10.1", + "installed_version": "0.10.2", "key": "toml", "package_name": "toml", "required_version": null } ], "package": { - "installed_version": "6.1.1", + "installed_version": "6.2.4", "key": "pytest", "package_name": "pytest" } @@ -717,7 +786,7 @@ "required_version": ">=2.5,<3" }, { - "installed_version": "1.26.3", + "installed_version": "1.26.5", "key": "urllib3", "package_name": "urllib3", "required_version": ">=1.21.1,<1.27" @@ -732,7 +801,7 @@ { "dependencies": [ { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.0.2,<2" @@ -747,7 +816,7 @@ { 
"dependencies": [], "package": { - "installed_version": "47.1.1", + "installed_version": "57.0.0", "key": "setuptools", "package_name": "setuptools" } @@ -755,7 +824,7 @@ { "dependencies": [], "package": { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six" } @@ -763,14 +832,14 @@ { "dependencies": [ { - "installed_version": "1.1.0", + "installed_version": "1.2.1", "key": "mpmath", "package_name": "mpmath", "required_version": ">=0.19" } ], "package": { - "installed_version": "1.6.2", + "installed_version": "1.8", "key": "sympy", "package_name": "sympy" } @@ -778,7 +847,7 @@ { "dependencies": [], "package": { - "installed_version": "0.10.1", + "installed_version": "0.10.2", "key": "toml", "package_name": "toml" } @@ -794,7 +863,15 @@ { "dependencies": [], "package": { - "installed_version": "3.7.4.3", + "installed_version": "2.12.0", + "key": "typeguard", + "package_name": "typeguard" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "3.10.0.0", "key": "typing-extensions", "package_name": "typing-extensions" } @@ -808,7 +885,7 @@ "required_version": ">=0.3.0" }, { - "installed_version": "3.7.4.3", + "installed_version": "3.10.0.0", "key": "typing-extensions", "package_name": "typing-extensions", "required_version": ">=3.7.4" @@ -823,7 +900,7 @@ { "dependencies": [], "package": { - "installed_version": "1.26.3", + "installed_version": "1.26.5", "key": "urllib3", "package_name": "urllib3" } @@ -863,7 +940,7 @@ "required_version": ">=2.0.0,<3.0.0" }, { - "installed_version": "1.10.0", + "installed_version": "1.9.5", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.9.5,<2.0.0" @@ -893,7 +970,7 @@ "required_version": ">=1.1.6,<2.0.0" }, { - "installed_version": "3.15.1", + "installed_version": "3.17.1", "key": "protobuf", "package_name": "protobuf", "required_version": ">=3.10.0,<4" @@ -905,7 +982,7 @@ "required_version": ">=2.16.0,<3.0.0" }, { - "installed_version": 
"3.7.4.3", + "installed_version": "3.10.0.0", "key": "typing-extensions", "package_name": "typing-extensions", "required_version": ">=3.7.4.1,<4" @@ -918,7 +995,7 @@ } ], "package": { - "installed_version": "5.16.0", + "installed_version": "5.19.0", "key": "web3", "package_name": "web3" } @@ -934,15 +1011,42 @@ { "dependencies": [], "package": { - "installed_version": "0.34.2", + "installed_version": "0.36.2", "key": "wheel", "package_name": "wheel" } }, + { + "dependencies": [ + { + "installed_version": "2.10", + "key": "idna", + "package_name": "idna", + "required_version": ">=2.0" + }, + { + "installed_version": "5.1.0", + "key": "multidict", + "package_name": "multidict", + "required_version": ">=4.0" + }, + { + "installed_version": "3.10.0.0", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": ">=3.7.4" + } + ], + "package": { + "installed_version": "1.6.3", + "key": "yarl", + "package_name": "yarl" + } + }, { "dependencies": [], "package": { - "installed_version": "3.4.0", + "installed_version": "3.4.1", "key": "zipp", "package_name": "zipp" } diff --git a/scripts/requirements-gen.txt b/scripts/requirements-gen.txt index b2c2530e..ad16de61 100644 --- a/scripts/requirements-gen.txt +++ b/scripts/requirements-gen.txt @@ -1,6 +1,8 @@ +aiohttp ecdsa eth-hash[pycryptodome]==0.2.0 fastecdsa +frozendict==1.2 lark-parser==0.8.5 marshmallow-dataclass>=7.1.0 marshmallow-enum @@ -11,4 +13,5 @@ numpy pipdeptree pytest sympy +typeguard Web3 diff --git a/scripts/requirements.txt b/scripts/requirements.txt index 4ab7f7c5..0d29d0e2 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -1,12 +1,14 @@ # This file is autogenerated. Do not edit manually. 
-attrs==20.2.0 +aiohttp==3.7.4.post0 +async-timeout==3.0.1 +attrs==21.2.0 base58==2.1.0 bitarray==1.2.2 certifi==2020.12.5 chardet==4.0.0 cytoolz==0.11.0 -ecdsa==0.16.0 +ecdsa==0.17.0 eth-abi==2.1.1 eth-account==0.5.4 eth-hash[pycryptodome]==0.2.0 @@ -14,45 +16,49 @@ eth-keyfile==0.5.1 eth-keys==0.3.3 eth-rlp==0.2.1 eth-typing==2.2.2 -eth-utils==1.10.0 +eth-utils==1.9.5 fastecdsa==2.1.5 +frozendict==1.2 hexbytes==0.2.1 idna==2.10 -importlib-metadata==2.0.0 +importlib-metadata==4.3.1 iniconfig==1.1.1 ipfshttpclient==0.7.0a1 jsonschema==3.2.0 lark-parser==0.8.5 lru-dict==1.1.7 -marshmallow==3.8.0 -marshmallow-dataclass==8.1.0 +marshmallow==3.12.1 +marshmallow-dataclass==8.4.1 marshmallow-enum==1.5.1 marshmallow-oneofschema==2.1.0 -mpmath==1.1.0 +mpmath==1.2.1 multiaddr==0.0.9 +multidict==5.1.0 mypy-extensions==0.4.3 netaddr==0.8.0 -numpy==1.19.2 -packaging==20.4 +numpy==1.20.3 +packaging==20.9 parsimonious==0.8.1 -pipdeptree==1.0.0 +pipdeptree==2.0.0 pluggy==0.13.1 -protobuf==3.15.1 -py==1.9.0 -pycryptodome==3.9.9 +protobuf==3.17.1 +py==1.10.0 +pycryptodome==3.10.1 pyparsing==2.4.7 pyrsistent==0.17.3 -pytest==6.1.1 +pytest==6.2.4 requests==2.25.1 rlp==2.0.1 -six==1.15.0 -sympy==1.6.2 -toml==0.10.1 +six==1.16.0 +sympy==1.8 +toml==0.10.2 toolz==0.11.1 -typing-extensions==3.7.4.3 +typeguard==2.12.0 +typing-extensions==3.10.0.0 typing-inspect==0.6.0 -urllib3==1.26.3 +urllib3==1.26.5 varint==1.0.2 -web3==5.16.0 +web3==5.19.0 websockets==8.1 -zipp==3.4.0 +yarl==1.6.3 +zipp==3.4.1 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9f9a94e4..4db761f4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(services) add_subdirectory(starkware) diff --git a/src/cmake_utils/gen_py_lib.py b/src/cmake_utils/gen_py_lib.py index 10ab63a5..6cb2baae 100755 --- a/src/cmake_utils/gen_py_lib.py +++ b/src/cmake_utils/gen_py_lib.py @@ -33,10 +33,11 @@ def extract_licenses(filename: str) -> List[str]: prefix = 'License: ' - with open(filename, 
encoding='utf8') as fp: - for line in fp.readlines(): - if line.startswith(prefix): - return line.strip()[len(prefix):].split(',') + if os.path.isfile(filename): + with open(filename, encoding='utf8') as fp: + for line in fp.readlines(): + if line.startswith(prefix): + return line.strip()[len(prefix):].split(',') return [] @@ -58,6 +59,10 @@ def main(): parser.add_argument('--output', type=str, help='Output info file', required=True) parser.add_argument( '--py_exe_deps', type=str, nargs='*', required=True, help='List of executable dependencies') + parser.add_argument( + '--cmake_dir', type=str, nargs='?', help='Directory of this CMake target', required=False) + parser.add_argument( + '--prefix', type=str, nargs='?', help='Prefix of this CMake target', required=False) args = parser.parse_args() # Try to extract license if possible. @@ -82,6 +87,8 @@ def main(): lib_deps=args.lib_deps, py_exe_deps=args.py_exe_deps, licenses=licenses, + cmake_dir=args.cmake_dir, + prefix=args.prefix, ), fp, sort_keys=True, diff --git a/src/cmake_utils/gen_python_exe.py b/src/cmake_utils/gen_python_exe.py index be9e9f6e..86123e96 100755 --- a/src/cmake_utils/gen_python_exe.py +++ b/src/cmake_utils/gen_python_exe.py @@ -55,6 +55,7 @@ def main(): export BUILD_ROOT=$(realpath $(dirname $0)/{build_path_bash}) {cd_command} +source ${{BUILD_ROOT}}/{venv_dir_rel}/bin/activate {args.environment_variables} \ CMAKE_TARGET_NAME={args.name} \ ${{BUILD_ROOT}}/{venv_dir_rel}/bin/python -u -m {args.module} \ diff --git a/src/cmake_utils/pip_rules.cmake b/src/cmake_utils/pip_rules.cmake index 03c443af..b888973a 100644 --- a/src/cmake_utils/pip_rules.cmake +++ b/src/cmake_utils/pip_rules.cmake @@ -38,9 +38,15 @@ function(python_pip TARGET) COMMENT "Building wheel ${REQ} for ${INTERPRETER}" COMMAND ${CMAKE_COMMAND} -E make_directory ${LIB_DIR} COMMAND ${CMAKE_COMMAND} -E make_directory ${DOWNLOAD_DIR} - COMMAND ${INTERPRETER} -m pip wheel --no-deps -w ${DOWNLOAD_DIR}/ ${REQ} + COMMAND + ${INTERPRETER} 
-m pip wheel --no-deps -w ${DOWNLOAD_DIR}/ ${REQ} ${PIP_INSTALL_ARGS} # Extract wheel. COMMAND cd ${LIB_DIR} && ${CMAKE_COMMAND} -E tar xzf ${DOWNLOAD_DIR}/*.whl + # Some wheels may put their files at /{name}-{version}.data/(pure|plat)lib/, instead of under + # the root directory. See https://www.python.org/dev/peps/pep-0427/#id24. + # Copy the files from there. Suppress errors, which happen most of the times when this + # subdirectory does not exist. + COMMAND cp -r ${LIB_DIR}/*.data/*lib/* ${LIB_DIR}/ > /dev/null 2>&1 || true # Cleanup download. COMMAND ${CMAKE_COMMAND} -E remove_directory ${DOWNLOAD_DIR} # Timestamp. diff --git a/src/cmake_utils/python_rules.cmake b/src/cmake_utils/python_rules.cmake index 59793d37..9be862ed 100644 --- a/src/cmake_utils/python_rules.cmake +++ b/src/cmake_utils/python_rules.cmake @@ -90,6 +90,7 @@ function(python_lib LIB) endforeach() get_lib_info_file(INFO_FILE ${LIB}) + file(RELATIVE_PATH CMAKE_DIR ${CMAKE_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) add_custom_command( OUTPUT ${INFO_FILE} COMMAND ${GEN_PY_LIB_EXECUTABLE} @@ -99,6 +100,8 @@ function(python_lib LIB) --lib_deps ${ARGS_LIBS} --output ${INFO_FILE} --py_exe_deps ${ARGS_PY_EXE_DEPENDENCIES} + --cmake_dir ${CMAKE_DIR} + --prefix ${ARGS_PREFIX} DEPENDS ${GEN_PY_LIB_EXECUTABLE} ${DEP_INFO} ${UNITED_LIBS} ${ARGS_PY_EXE_DEPENDENCIES} ${ALL_FILE_DEPS} ${LIB}_copy_files ) diff --git a/src/services/CMakeLists.txt b/src/services/CMakeLists.txt new file mode 100644 index 00000000..46d7aaf5 --- /dev/null +++ b/src/services/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(external_api) +add_subdirectory(everest) diff --git a/src/services/everest/CMakeLists.txt b/src/services/everest/CMakeLists.txt new file mode 100644 index 00000000..f30a541c --- /dev/null +++ b/src/services/everest/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(api) +add_subdirectory(definitions) diff --git a/src/services/everest/api/CMakeLists.txt b/src/services/everest/api/CMakeLists.txt new file mode 100644 
index 00000000..2222a2a6 --- /dev/null +++ b/src/services/everest/api/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(feeder_gateway) +add_subdirectory(gateway) diff --git a/src/services/everest/api/feeder_gateway/CMakeLists.txt b/src/services/everest/api/feeder_gateway/CMakeLists.txt new file mode 100644 index 00000000..7fedfc08 --- /dev/null +++ b/src/services/everest/api/feeder_gateway/CMakeLists.txt @@ -0,0 +1,9 @@ +python_lib(everest_feeder_gateway_client_lib + PREFIX services/everest/api/feeder_gateway + + FILES + feeder_gateway_client.py + + LIBS + services_external_api_lib +) diff --git a/src/services/everest/api/feeder_gateway/feeder_gateway_client.py b/src/services/everest/api/feeder_gateway/feeder_gateway_client.py new file mode 100644 index 00000000..d552bc1d --- /dev/null +++ b/src/services/everest/api/feeder_gateway/feeder_gateway_client.py @@ -0,0 +1,19 @@ +import json +from typing import ClassVar + +from services.external_api.base_client import BaseClient + + +class EverestFeederGatewayClient(BaseClient): + """ + Base class to FeederGatewayClient classes. 
+ """ + + prefix: ClassVar[str] = '/feeder_gateway' + + async def is_alive(self) -> str: + return await self._send_request(send_method='GET', uri='/is_alive') + + async def get_last_batch_id(self) -> int: + raw_response = await self._send_request(send_method='GET', uri='/get_last_batch_id') + return json.loads(raw_response) diff --git a/src/services/everest/api/gateway/CMakeLists.txt b/src/services/everest/api/gateway/CMakeLists.txt new file mode 100644 index 00000000..72f03b8c --- /dev/null +++ b/src/services/everest/api/gateway/CMakeLists.txt @@ -0,0 +1,28 @@ +python_lib(everest_transaction_type_lib + PREFIX services/everest/api/gateway + + FILES + transaction_type.py +) + +python_lib(everest_transaction_lib + PREFIX services/everest/api/gateway + + FILES + transaction.py + + LIBS + starkware_utils_lib + pip_marshmallow_oneofschema +) + +python_lib(everest_gateway_client_lib + PREFIX services/everest/api/gateway + + FILES + gateway_client.py + + LIBS + everest_transaction_lib + services_external_api_lib +) diff --git a/src/services/everest/api/gateway/gateway_client.py b/src/services/everest/api/gateway/gateway_client.py new file mode 100644 index 00000000..7a357492 --- /dev/null +++ b/src/services/everest/api/gateway/gateway_client.py @@ -0,0 +1,26 @@ +import json +from typing import ClassVar, Dict + +from services.everest.api.gateway.transaction import EverestAddTransactionRequest +from services.external_api.base_client import BaseClient + + +class EverestGatewayClient(BaseClient): + """ + Base class to GatewayClient classes. 
+ """ + + prefix: ClassVar[str] = '/gateway' + + async def is_alive(self) -> str: + return await self._send_request(send_method='GET', uri='/is_alive') + + async def add_transaction_request( + self, add_tx_request: EverestAddTransactionRequest) -> Dict[str, str]: + raw_response = await self._send_request( + send_method='POST', uri='/add_transaction', data=add_tx_request.dumps()) + return json.loads(raw_response) + + async def get_first_unused_tx_id(self) -> int: + response = await self._send_request(send_method='GET', uri='/get_first_unused_tx_id') + return json.loads(response) diff --git a/src/services/everest/api/gateway/transaction.py b/src/services/everest/api/gateway/transaction.py new file mode 100644 index 00000000..78c5c20e --- /dev/null +++ b/src/services/everest/api/gateway/transaction.py @@ -0,0 +1,19 @@ +from typing import ClassVar, Type + +import marshmallow_oneofschema + +from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass + + +class EverestTransaction(ValidatedMarshmallowDataclass): + """ + Base class of application-specific external transaction base classes. + Contains the API of an external transaction. + """ + + Schema: ClassVar[Type[marshmallow_oneofschema.OneOfSchema]] + + +class EverestAddTransactionRequest(ValidatedMarshmallowDataclass): + tx: EverestTransaction + tx_id: int diff --git a/src/services/everest/api/gateway/transaction_type.py b/src/services/everest/api/gateway/transaction_type.py new file mode 100644 index 00000000..5b7acb5a --- /dev/null +++ b/src/services/everest/api/gateway/transaction_type.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class TransactionTypeBase(Enum): + """ + Base class of all transaction type enums. + Do not add enum members to this class, only functionality. + See: https://docs.python.org/3/library/enum.html#restricted-enum-subclassing. 
+ """ diff --git a/src/services/everest/definitions/CMakeLists.txt b/src/services/everest/definitions/CMakeLists.txt new file mode 100644 index 00000000..dfff254c --- /dev/null +++ b/src/services/everest/definitions/CMakeLists.txt @@ -0,0 +1,10 @@ +python_lib(everest_definitions_lib + PREFIX services/everest/definitions + + FILES + fields.py + + LIBS + starkware_utils_lib + pip_marshmallow +) diff --git a/src/services/everest/definitions/fields.py b/src/services/everest/definitions/fields.py new file mode 100644 index 00000000..e2614eb0 --- /dev/null +++ b/src/services/everest/definitions/fields.py @@ -0,0 +1,9 @@ +import marshmallow.fields as mfields + +from starkware.starkware_utils.field_validators import validate_non_negative + +# Fields data: validation data, dataclass metadata. +tx_id_marshmallow_field = mfields.Integer( + strict=True, required=True, validate=validate_non_negative('tx_id')) + +tx_id_field_metadata = dict(marshmallow_field=tx_id_marshmallow_field) diff --git a/src/services/external_api/CMakeLists.txt b/src/services/external_api/CMakeLists.txt new file mode 100644 index 00000000..d8ac5620 --- /dev/null +++ b/src/services/external_api/CMakeLists.txt @@ -0,0 +1,12 @@ +python_lib(services_external_api_lib + PREFIX services/external_api + + FILES + base_client.py + has_uri_prefix.py + ${SERVICES_EXTERNAL_API_LIB_ADDITIONAL_FILES} + + LIBS + pip_aiohttp + ${SERVICES_EXTERNAL_API_LIB_ADDITIONAL_LIBS} +) diff --git a/src/services/external_api/base_client.py b/src/services/external_api/base_client.py new file mode 100644 index 00000000..5d8169a9 --- /dev/null +++ b/src/services/external_api/base_client.py @@ -0,0 +1,120 @@ +import asyncio +import dataclasses +import logging +import os +import ssl +from http import HTTPStatus +from typing import Any, Dict, Optional, Sequence, Union +from urllib.parse import urljoin + +import aiohttp + +from services.external_api.has_uri_prefix import HasUriPrefix + +logger = logging.getLogger(__name__) + + +class 
BadRequest(Exception): + """ + Base class to exceptions raised by BaseClient and its derived classes. + """ + + def __init__(self, status_code: int, text: str): + self.status_code = status_code + self.text = text + + def __repr__(self) -> str: + return f'HTTP error ocurred. Status: {self.status_code}. Text: {self.text}' + + def __str__(self) -> str: + """ + Overrides base's str method, which returns an empty string (so it falls back to repr). + """ + return self.__repr__() + + +@dataclasses.dataclass(frozen=True) +class RetryConfig: + """ + A configuration defining the retry protocol for failed HTTP requests. + """ + + # Set n_retries == -1 for unlimited retries (for any error type). + n_retries: int = 30 + retry_codes: Sequence[HTTPStatus] = ( + HTTPStatus.BAD_GATEWAY, HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.GATEWAY_TIMEOUT) + + +class BaseClient(HasUriPrefix): + """ + A base class for HTTP clients. + """ + + def __init__( + self, url: str, certificates_path: Optional[str] = None, + retry_config: Optional[RetryConfig] = None): + self.url = url + self.ssl_context: Optional[ssl.SSLContext] = None + + self.retry_config = RetryConfig() if retry_config is None else retry_config + assert self.retry_config.n_retries > 0 or self.retry_config.n_retries == -1, \ + 'RetryConfig n_retries parameter value must be either a positive int or equals to -1.' + + if certificates_path is not None: + self.ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLSv1_2) + self.ssl_context.verify_mode = ssl.CERT_REQUIRED + self.ssl_context.check_hostname = True + + self.ssl_context.load_cert_chain( + certfile=os.path.join(certificates_path, 'user.crt'), + keyfile=os.path.join(certificates_path, 'user.key')) + + self.ssl_context.load_verify_locations(os.path.join(certificates_path, 'server.crt')) + + async def _send_request( + self, send_method: str, uri: str, + data: Optional[Union[str, Dict[str, Any]]] = None) -> str: + """ + Sends an HTTP request to the target URI. 
+ Retries upon failure according to the retry configuration: + 1. In case of unlimited retries (n_retries == -1): always retries upon failure. + 2. In case of limited retries (n_retries > 0): + a. Retries n_retries times for specified error types. + b. Raises an exception immediately for other error types. + """ + url = urljoin(base=self.url, url=self.format_uri(uri)) + + limited_retries = self.retry_config.n_retries > 0 + # n_retries > 0 means limited retries; n_retries == -1 means unlimited retries. + n_retries_left = self.retry_config.n_retries + + while True: + n_retries_left -= 1 + + try: + async with aiohttp.TCPConnector(ssl=self.ssl_context) as connector: + async with aiohttp.ClientSession(connector=connector) as session: + async with session.request( + method=send_method, url=url, data=data) as response: + text = await response.text() + if response.status != HTTPStatus.OK: + raise BadRequest(status_code=response.status, text=text) + + return text + except aiohttp.ClientError: + if limited_retries and n_retries_left == 0: + raise + + logger.error('ClientConnectorError, retrying...', exc_info=True) + except BadRequest as exception: + if limited_retries and ( + n_retries_left == 0 or + exception.status_code not in self.retry_config.retry_codes): + raise + + logger.error(f'BadRequest with code {exception.status_code}, retrying...') + + await asyncio.sleep(1) + + async def is_alive(self) -> str: + return await self._send_request(send_method='GET', uri='/is_alive') diff --git a/src/services/external_api/has_uri_prefix.py b/src/services/external_api/has_uri_prefix.py new file mode 100644 index 00000000..2ded7f0a --- /dev/null +++ b/src/services/external_api/has_uri_prefix.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from typing import cast + + +class HasUriPrefix(ABC): + """ + A base class of HTTP Gateway services. + """ + @property + @classmethod + @abstractmethod + def prefix(cls) -> str: + """ + Returns the prefix of the gateway URIs. 
+ Subclasses should define it as a class variable. + """ + + @classmethod + def format_uri(cls, name: str) -> str: + """ + Concatenates cls.prefix with given URI. + """ + prefix = cast(str, cls.prefix) # Mypy sees the property as a callable. + return name if len(prefix) == 0 else f'{cls.prefix}{name}' diff --git a/src/starkware/CMakeLists.txt b/src/starkware/CMakeLists.txt index 3bb04c3d..e0ed2724 100644 --- a/src/starkware/CMakeLists.txt +++ b/src/starkware/CMakeLists.txt @@ -1,3 +1,5 @@ -add_subdirectory(crypto) add_subdirectory(cairo) +add_subdirectory(crypto) add_subdirectory(python) +add_subdirectory(starknet) +add_subdirectory(starkware_utils) diff --git a/src/starkware/cairo/apps/starkex2_0/CMakeLists.txt b/src/starkware/cairo/apps/starkex2_0/CMakeLists.txt index a986645c..f3b4641d 100644 --- a/src/starkware/cairo/apps/starkex2_0/CMakeLists.txt +++ b/src/starkware/cairo/apps/starkex2_0/CMakeLists.txt @@ -9,6 +9,7 @@ full_python_test(starkex2_0_program_test TESTED_MODULES starkware/cairo/apps/starkex2_0 FILES + __start__.cairo common/cairo_builtins.cairo common/dict.cairo common/merkle_multi_update.cairo diff --git a/src/starkware/cairo/apps/starkex2_0/__start__.cairo b/src/starkware/cairo/apps/starkex2_0/__start__.cairo new file mode 100644 index 00000000..6a653fc1 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/__start__.cairo @@ -0,0 +1,10 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +# Add the initial code that was removed from the compiler in order to get the correct program hash. +# This code is not required for the execution of the program using the bootloader. 
+__start__: +call rel 748 +jmp rel 0 diff --git a/src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo b/src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo index ce6768d8..2c76ef6f 100644 --- a/src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo +++ b/src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo @@ -15,10 +15,3 @@ struct SignatureBuiltin: member pub_key : felt member message : felt end - -# A representation of a CheckpointsBuiltin struct, specifying the checkpoints builtin memory -# structure. -struct CheckpointsBuiltin: - member required_pc : felt - member required_fp : felt -end diff --git a/src/starkware/cairo/apps/starkex2_0/common/dict.cairo b/src/starkware/cairo/apps/starkex2_0/common/dict.cairo index a20f037b..e2345480 100644 --- a/src/starkware/cairo/apps/starkex2_0/common/dict.cairo +++ b/src/starkware/cairo/apps/starkex2_0/common/dict.cairo @@ -95,7 +95,7 @@ func squash_dict_inner( locals.first_value = first_access.prev_value locals.first_value = dict_diff.prev_value - # Skip loop non-deterministically if necessary. + # Skip loop nondeterministically if necessary. %{ memory[fp + ids.Locals.should_skip_loop] = 0 if current_access_indices else 1 %} jmp skip_loop if [fp + Locals.should_skip_loop] != 0 diff --git a/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo b/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo index 72a3dcc0..f2439b74 100644 --- a/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo +++ b/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo @@ -113,7 +113,8 @@ func execute_batch( # Call execute_batch recursively. # Make a copy of the first argument to avoid a compiler optimization that was added after the # code was deployed. 
- tempvar modification_ptr = modification_ptr + [ap] = modification_ptr; ap++ + let modification_ptr = cast([ap - 1], ModificationOutput*) return execute_batch( modification_ptr=modification_ptr, conditional_transfer_ptr=conditional_transfer_ptr, diff --git a/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo b/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo index 0e371d6d..a5b30aef 100644 --- a/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo +++ b/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo @@ -78,7 +78,8 @@ func execute_transfer( # Call vault_update for the receiver. # Make a copy of the first argument to avoid a compiler optimization that was added after the # code was deployed. - tempvar range_check_ptr = sender_vault_update_ret.range_check_ptr + [ap] = sender_vault_update_ret.range_check_ptr; ap++ + let range_check_ptr = [ap - 1] let receiver_vault_update_ret = vault_update_diff( range_check_ptr=range_check_ptr, diff=amount, diff --git a/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo b/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo index ab03a3e5..4bf24a6f 100644 --- a/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo +++ b/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo @@ -46,7 +46,8 @@ func hash_vault_ptr_dict( # Make a copy of the first argument to avoid a compiler optimization that was added after the # code was deployed. - tempvar hash_ptr = prev_hash_res.hash_ptr + [ap] = prev_hash_res.hash_ptr; ap++ + let hash_ptr = cast([ap - 1], HashBuiltin*) let new_hash_res = hash_vault_state_ptr( hash_ptr=hash_ptr, vault_state_ptr=cast(vault_access.new_value, VaultState*)) hashed_vault_access.new_value = new_hash_res.vault_hash @@ -54,7 +55,8 @@ func hash_vault_ptr_dict( # Tail call. # Make a copy of the first argument to avoid a compiler optimization that was added after the # code was deployed. 
- tempvar hash_ptr = new_hash_res.hash_ptr + [ap] = new_hash_res.hash_ptr; ap++ + let hash_ptr = cast([ap - 1], HashBuiltin*) return hash_vault_ptr_dict( hash_ptr=hash_ptr, vault_ptr_dict=vault_ptr_dict + DictAccess.SIZE, diff --git a/src/starkware/cairo/apps/starkex2_0/main.cairo b/src/starkware/cairo/apps/starkex2_0/main.cairo index 89492173..a9d7af34 100644 --- a/src/starkware/cairo/apps/starkex2_0/main.cairo +++ b/src/starkware/cairo/apps/starkex2_0/main.cairo @@ -5,6 +5,7 @@ %builtins output pedersen range_check ecdsa +from starkware.cairo.apps.starkex2_0.__start__ import __start__ from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.apps.starkex2_0.common.dict import DictAccess, squash_dict from starkware.cairo.apps.starkex2_0.common.merkle_multi_update import merkle_multi_update @@ -112,7 +113,8 @@ func main( # Verify hashed_vault_dict consistency with the vault merkle root. # Make a copy of the first argument to avoid a compiler optimization that was added after the # code was deployed. - tempvar hash_ptr = hash_vault_dict_ptr + [ap] = hash_vault_dict_ptr; ap++ + let hash_ptr = cast([ap - 1], HashBuiltin*) with hash_ptr: merkle_multi_update( update_ptr=hashed_vault_dict, @@ -124,7 +126,8 @@ func main( # Verify squashed_order_dict consistency with the order merkle root. # Make a copy of the first argument to avoid a compiler optimization that was added after # the code was deployed. 
- tempvar hash_ptr = hash_ptr + [ap] = hash_ptr; ap++ + let hash_ptr = cast([ap - 1], HashBuiltin*) merkle_multi_update( update_ptr=squashed_order_dict, n_updates=squashed_order_dict_segment_size / DictAccess.SIZE, diff --git a/src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py b/src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py index a6ed0c48..e2746fe1 100644 --- a/src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py +++ b/src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py @@ -29,6 +29,7 @@ def test_program_hash(): PROGRAM_MAIN_FILE, f'--output={compiled_program.name}', f'--cairo_path={CAIRO_PATH}', + '--no_opt_unused_functions', ]) program_hash = subprocess.check_output([ f'{CAIRO_HASH_PROGRAM_EXE}', diff --git a/src/starkware/cairo/bootloader/CMakeLists.txt b/src/starkware/cairo/bootloader/CMakeLists.txt index b37e8a39..37a5134b 100644 --- a/src/starkware/cairo/bootloader/CMakeLists.txt +++ b/src/starkware/cairo/bootloader/CMakeLists.txt @@ -54,7 +54,7 @@ python_lib(cairo_bootloader_generate_fact_lib LIBS cairo_bootloader_fact_topology_lib cairo_hash_program_lib - cairo_relocatable + cairo_relocatable_lib cairo_vm_lib pip_eth_hash pip_pycryptodome diff --git a/src/starkware/cairo/bootloader/fact_topology.py b/src/starkware/cairo/bootloader/fact_topology.py index 1f1da92c..8fa15e23 100644 --- a/src/starkware/cairo/bootloader/fact_topology.py +++ b/src/starkware/cairo/bootloader/fact_topology.py @@ -8,13 +8,13 @@ GPS_FACT_TOPOLOGY = 'gps_fact_topology' -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class FactTopology: tree_structure: List[int] page_sizes: List[int] -@marshmallow_dataclass.dataclass +@marshmallow_dataclass.dataclass(frozen=True) class FactTopologiesFile: fact_topologies: List[FactTopology] Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema @@ -24,7 +24,7 @@ def load_fact_topologies(path) -> List[FactTopology]: return 
FactTopologiesFile.Schema().load(json.load(open(path))).fact_topologies -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class FactInfo: program_output: List[int] fact_topology: FactTopology diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index 85f7f4f0..616e1500 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -12,6 +12,8 @@ python_lib(cairo_common_lib hash_chain.py hash_state.cairo hash.cairo + invoke.cairo + math_cmp.cairo math_utils.py math.cairo memcpy.cairo @@ -19,10 +21,13 @@ python_lib(cairo_common_lib merkle_update.cairo registers.cairo serialize.cairo + set.cairo signature.cairo small_merkle_tree.cairo small_merkle_tree.py squash_dict.cairo + structs.py + uint256.cairo ${CAIRO_COMMON_LIB_ADDITIONAL_FILES} LIBS @@ -30,3 +35,19 @@ python_lib(cairo_common_lib starkware_merkle_tree_lib ${CAIRO_COMMON_LIB_ADDITIONAL_LIBS} ) + +python_lib(cairo_function_runner_lib + PREFIX starkware/cairo/common + + FILES + cairo_function_runner.py + + LIBS + cairo_common_lib + cairo_compile_lib + cairo_run_builtins_lib + cairo_run_lib + cairo_tracer_lib + cairo_vm_crypto_lib + cairo_vm_lib +) diff --git a/src/starkware/cairo/common/cairo_builtins.cairo b/src/starkware/cairo/common/cairo_builtins.cairo index fabceb40..56d54d59 100644 --- a/src/starkware/cairo/common/cairo_builtins.cairo +++ b/src/starkware/cairo/common/cairo_builtins.cairo @@ -10,10 +10,3 @@ struct SignatureBuiltin: member pub_key : felt member message : felt end - -# A representation of a CheckpointsBuiltin struct, specifying the checkpoints builtin memory -# structure. 
-struct CheckpointsBuiltin: - member required_pc : felt - member required_fp : felt -end diff --git a/src/starkware/cairo/common/cairo_function_runner.py b/src/starkware/cairo/common/cairo_function_runner.py new file mode 100644 index 00000000..aea2516c --- /dev/null +++ b/src/starkware/cairo/common/cairo_function_runner.py @@ -0,0 +1,150 @@ +from collections.abc import Iterable +from typing import Any, Dict, Optional, Union, cast + +from starkware.cairo.common.structs import CairoStructFactory +from starkware.cairo.lang.builtins.hash.hash_builtin_runner import HashBuiltinRunner +from starkware.cairo.lang.builtins.range_check.range_check_builtin_runner import ( + RangeCheckBuiltinRunner) +from starkware.cairo.lang.builtins.signature.signature_builtin_runner import SignatureBuiltinRunner +from starkware.cairo.lang.compiler.identifier_definition import LabelDefinition +from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.tracer.tracer import trace_runner +from starkware.cairo.lang.vm.cairo_runner import CairoRunner, process_ecdsa, verify_ecdsa_sig +from starkware.cairo.lang.vm.crypto import pedersen_hash +from starkware.cairo.lang.vm.output_builtin_runner import OutputBuiltinRunner +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +from starkware.cairo.lang.vm.security import SecurityError, verify_secure_runner +from starkware.cairo.lang.vm.vm import VmException + + +class CairoFunctionRunner(CairoRunner): + def __init__(self, program, layout='plain'): + super().__init__(program=program, layout=layout) + + pedersen_builtin = HashBuiltinRunner( + name='pedersen', included=True, ratio=32, hash_func=pedersen_hash) + self.builtin_runners['pedersen_builtin'] = pedersen_builtin + range_check_builtin = RangeCheckBuiltinRunner( + included=True, ratio=None, inner_rc_bound=2 ** 16, n_parts=8) + self.builtin_runners['range_check_builtin'] = 
range_check_builtin + output_builtin = OutputBuiltinRunner(included=True) + self.builtin_runners['output_builtin'] = output_builtin + signature_builtin = SignatureBuiltinRunner( + name='ecdsa', included=True, ratio=None, process_signature=process_ecdsa, + verify_signature=verify_ecdsa_sig) + self.builtin_runners['ecdsa_builtin'] = signature_builtin + + self.initialize_segments() + + @property + def pedersen_builtin(self) -> HashBuiltinRunner: + return cast(HashBuiltinRunner, self.builtin_runners['pedersen_builtin']) + + @property + def range_check_builtin(self) -> RangeCheckBuiltinRunner: + return cast(RangeCheckBuiltinRunner, self.builtin_runners['range_check_builtin']) + + @property + def output_builtin(self) -> OutputBuiltinRunner: + return cast(OutputBuiltinRunner, self.builtin_runners['output_builtin']) + + @property + def ecdsa_builtin(self) -> SignatureBuiltinRunner: + return cast(SignatureBuiltinRunner, self.builtin_runners['ecdsa_builtin']) + + def assert_eq(self, arg: MaybeRelocatable, expected_value, apply_modulo: bool = True): + """ + Asserts that arg is the Cairo representation of expected_value. + If expected_value is Iterable then arg is interpreted as a pointer to a list + and assert_eq is called recursively on all the items in expected_value. + If apply_modulo=True, all the integers are taken modulo the program's prime. + """ + assert isinstance(arg, (int, RelocatableValue)), f'Expecting MaybeRelocatable got {arg}' + + if isinstance(expected_value, Iterable): + for idx, value in enumerate(expected_value): + self.assert_eq(self.vm_memory[arg + idx], value, apply_modulo=apply_modulo) + return + + if apply_modulo and isinstance(arg, int): + expected_value = expected_value % self.program.prime + + assert arg == expected_value, f'{arg} does not equal expected value {expected_value}.' 
+ + def run( + self, func_name: str, *args, hint_locals: Optional[Dict[str, Any]] = None, + static_locals: Optional[Dict[str, Any]] = None, + verify_secure: Optional[bool] = None, trace_on_failure: bool = False, + apply_modulo_to_args: Optional[bool] = None, use_full_name: bool = False, **kwargs): + """ + Runs func_name(*args). + args are converted to Cairo-friendly ones using gen_arg. + + Additional params: + verify_secure - Run verify_secure_runner to do extra verifications. + trace_on_failure - Run the tracer in case of failure to help debugging. + apply_modulo_to_args - Apply modulo operation on integer arguments. + use_full_name - Treat func_name as a fully qualified identifer name, instance of a relative + one. + """ + assert isinstance(self.program, Program) + structs_factory = CairoStructFactory.from_program(program=self.program) + full_args_struct = structs_factory.build_func_args( + func=ScopedName.from_string(scope=func_name)) + all_args = full_args_struct(*args, **kwargs) + + entrypoint: Union[str, int] + if use_full_name: + identifier = self.program.identifiers.get_by_full_name( + name=ScopedName.from_string(scope=func_name)) + assert isinstance(identifier, LabelDefinition) + entrypoint = identifier.pc + else: + entrypoint = func_name + + try: + self.run_from_entrypoint( + entrypoint, *all_args, hint_locals=hint_locals, static_locals=static_locals, + verify_secure=verify_secure, apply_modulo_to_args=apply_modulo_to_args) + except (VmException, SecurityError, AssertionError) as ex: + if trace_on_failure: + print(f"""\ +Got {type(ex).__name__} exception during the execution of {func_name}: +{str(ex)} +""") + trace_runner(runner=self) + raise + + def run_from_entrypoint( + self, entrypoint: Union[str, int], *args, hint_locals: Optional[Dict[str, Any]] = None, + static_locals: Optional[Dict[str, Any]] = None, max_steps: Optional[int] = None, + verify_secure: Optional[bool] = None, apply_modulo_to_args: Optional[bool] = None): + """ + Runs the program from 
the given entrypoint. + + Additional params: + verify_secure - Run verify_secure_runner to do extra verifications. + apply_modulo_to_args - Apply modulo operation on integer arguments. + """ + if hint_locals is None: + hint_locals = {} + + if verify_secure is None: + verify_secure = True + + if apply_modulo_to_args is None: + apply_modulo_to_args = True + + real_args = [self.gen_arg(arg=x, apply_modulo_to_args=apply_modulo_to_args) for x in args] + end = self.initialize_function_entrypoint(entrypoint=entrypoint, args=real_args) + self.initialize_vm(hint_locals=hint_locals, static_locals=static_locals) + + self.run_until_pc(addr=end, max_steps=max_steps) + self.end_run() + + if verify_secure: + verify_secure_runner(runner=self, verify_builtins=False) + + def get_return_values(self, n_ret: int): + return self.vm_memory.get_range(addr=self.vm.run_context.ap - n_ret, size=n_ret) diff --git a/src/starkware/cairo/common/find_element.cairo b/src/starkware/cairo/common/find_element.cairo index fe2571af..5065eb71 100644 --- a/src/starkware/cairo/common/find_element.cairo +++ b/src/starkware/cairo/common/find_element.cairo @@ -1,8 +1,10 @@ from starkware.cairo.common.math import assert_le, assert_nn_le +const FIND_ELEMENT_RANGE_CHECK_USAGE = 2 + # Finds an element in the array whose first field is key and returns a pointer # to this element. -# Since cairo is non-deterministic this is an O(1) operation. +# Since cairo is nondeterministic this is an O(1) operation. # Note however that if the array has multiple elements with said key the function may return any # of those elements. # @@ -25,21 +27,35 @@ func find_element{range_check_ptr}(array_ptr : felt*, elm_size, n_elms, key) -> alloc_locals local index %{ + array_ptr = ids.array_ptr + elm_size = ids.elm_size + assert isinstance(elm_size, int) and elm_size > 0, \ + f'Invalid value for elm_size. Got: {elm_size}.' 
+ key = ids.key + if '__find_element_index' in globals(): ids.index = __find_element_index - found_key = memory[ids.array_ptr + ids.elm_size * __find_element_index] - assert found_key == ids.key, \ + found_key = memory[array_ptr + elm_size * __find_element_index] + assert found_key == key, \ f'Invalid index found in __find_element_index. index: {__find_element_index}, ' \ - f'expected key {ids.key}, found key: {found_key}.' + f'expected key {key}, found key: {found_key}.' # Delete __find_element_index to make sure it's not used for the next calls. del __find_element_index else: - for i in range(ids.n_elms): - if memory[ids.array_ptr + ids.elm_size * i] == ids.key: + n_elms = ids.n_elms + assert isinstance(n_elms, int) and n_elms >= 0, \ + f'Invalid value for n_elms. Got: {n_elms}.' + if '__find_element_max_size' in globals(): + assert n_elms <= __find_element_max_size, \ + f'find_element() can only be used with n_elms<={__find_element_max_size}. ' \ + f'Got: n_elms={n_elms}.' + + for i in range(n_elms): + if memory[array_ptr + elm_size * i] == key: ids.index = i break else: - raise ValueError(f'Key {ids.key} not found.') + raise ValueError(f'Key {key} was not found.') %} assert_nn_le(a=index, b=n_elms - 1) @@ -57,12 +73,25 @@ func search_sorted_lower{range_check_ptr}(array_ptr : felt*, elm_size, n_elms, k alloc_locals local index %{ - for i in range(ids.n_elms): - if memory[ids.array_ptr + ids.elm_size * i] >= ids.key: + array_ptr = ids.array_ptr + elm_size = ids.elm_size + assert isinstance(elm_size, int) and elm_size > 0, \ + f'Invalid value for elm_size. Got: {elm_size}.' + + n_elms = ids.n_elms + assert isinstance(n_elms, int) and n_elms >= 0, \ + f'Invalid value for n_elms. Got: {n_elms}.' + if '__find_element_max_size' in globals(): + assert n_elms <= __find_element_max_size, \ + f'find_element() can only be used with n_elms<={__find_element_max_size}. ' \ + f'Got: n_elms={n_elms}.' 
+ + for i in range(n_elms): + if memory[array_ptr + elm_size * i] >= ids.key: ids.index = i break else: - ids.index = ids.n_elms + ids.index = n_elms %} assert_nn_le(a=index, b=n_elms) diff --git a/src/starkware/cairo/common/hash_state.cairo b/src/starkware/cairo/common/hash_state.cairo index f26e2c44..3500fde8 100644 --- a/src/starkware/cairo/common/hash_state.cairo +++ b/src/starkware/cairo/common/hash_state.cairo @@ -24,13 +24,48 @@ end # A helper function for 'hash_update', see its documentation. # Computes the hash of an array of items, not including its length. -func hash_update_inner{hash_ptr : HashBuiltin*}(curr_ptr : felt*, data_length, hash) -> (hash): +func hash_update_inner{hash_ptr : HashBuiltin*}( + data_ptr : felt*, data_length : felt, hash : felt) -> (hash : felt): if data_length == 0: return (hash=hash) end - let (res) = hash2(x=hash, y=[curr_ptr]) - return hash_update_inner(curr_ptr=curr_ptr + 1, data_length=data_length - 1, hash=res) + alloc_locals + local data_last_ptr : felt* = data_ptr + data_length - 1 + struct LoopLocals: + member data_ptr : felt* + member hash_ptr : HashBuiltin* + member cur_hash : felt + end + + # Set up first iteration locals. + let first_locals : LoopLocals* = cast(ap, LoopLocals*) + first_locals.data_ptr = data_ptr; ap++ + first_locals.hash_ptr = hash_ptr; ap++ + first_locals.cur_hash = hash; ap++ + + # Do{. + hash_loop: + let prev_locals : LoopLocals* = cast(ap - LoopLocals.SIZE, LoopLocals*) + tempvar n_remaining_elements = data_last_ptr - prev_locals.data_ptr + + # Compute hash(cur_hash, [data_ptr]). + prev_locals.hash_ptr.x = prev_locals.cur_hash + assert prev_locals.hash_ptr.y = [prev_locals.data_ptr] # Allocates one memory cell. + + # Set up next iteration locals. 
+ let next_locals : LoopLocals* = cast(ap, LoopLocals*) + next_locals.data_ptr = prev_locals.data_ptr + 1; ap++ + next_locals.hash_ptr = prev_locals.hash_ptr + HashBuiltin.SIZE; ap++ + next_locals.cur_hash = prev_locals.hash_ptr.result; ap++ + + # } while(n_remaining_elements != 0). + jmp hash_loop if n_remaining_elements != 0 + + # Return values from final iteration. + let final_locals : LoopLocals* = cast(ap - LoopLocals.SIZE, LoopLocals*) + let hash_ptr = final_locals.hash_ptr + return (hash=final_locals.cur_hash) end # Adds each item in an array of items to the HashState. @@ -40,7 +75,7 @@ func hash_update{hash_ptr : HashBuiltin*}( new_hash_state_ptr : HashState*): alloc_locals let (hash) = hash_update_inner( - curr_ptr=data_ptr, data_length=data_length, hash=hash_state_ptr.current_hash) + data_ptr=data_ptr, data_length=data_length, hash=hash_state_ptr.current_hash) let (__fp__, _) = get_fp_and_pc() local new_hash_state : HashState new_hash_state.current_hash = hash diff --git a/src/starkware/cairo/common/invoke.cairo b/src/starkware/cairo/common/invoke.cairo new file mode 100644 index 00000000..8c884d54 --- /dev/null +++ b/src/starkware/cairo/common/invoke.cairo @@ -0,0 +1,21 @@ +# Calls func_ptr(args[0], args[1], ..., args[n_args - 1]) and forwards its return value. +# In order to convert a label to pc and use it as a value for the func_ptr argument, +# use get_label_location(). +func invoke(func_ptr, n_args : felt, args : felt*): + invoke_prepare_args(args_end=args + n_args, n_args=n_args) + call abs func_ptr + ret +end + +# Helper function for invoke(). +# Copies the memory range [args_end - n_args, args_end) to the memory range +# [final_ap - n_args, final_ap) where final_ap is the value of ap when the function returns. 
+func invoke_prepare_args(args_end : felt*, n_args : felt): + if n_args == 0: + return () + end + + invoke_prepare_args(args_end=args_end - 1, n_args=n_args - 1) + [ap] = [args_end - 1]; ap++ + return () +end diff --git a/src/starkware/cairo/common/math.cairo b/src/starkware/cairo/common/math.cairo index 0d289963..531abc33 100644 --- a/src/starkware/cairo/common/math.cairo +++ b/src/starkware/cairo/common/math.cairo @@ -2,7 +2,11 @@ # Verifies that value != 0. The proof will fail otherwise. func assert_not_zero(value): - %{ assert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.' %} + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.value) + assert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.' + %} if value == 0: # If value == 0, add an unsatisfiable requirement. value = 1 @@ -13,7 +17,16 @@ end # Verifies that a != b. The proof will fail otherwise. func assert_not_equal(a, b): - %{ assert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.' %} + %{ + from starkware.cairo.lang.vm.relocatable import RelocatableValue + both_ints = isinstance(ids.a, int) and isinstance(ids.b, int) + both_relocatable = ( + isinstance(ids.a, RelocatableValue) and isinstance(ids.b, RelocatableValue) and + ids.a.segment_index == ids.b.segment_index) + assert both_ints or both_relocatable, \ + f'assert_not_equal failed: non-comparable values: {ids.a}, {ids.b}.' + assert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.' + %} if a == b: # If a == b, add an unsatisfiable requirement. [fp - 1] = [fp - 1] + 1 @@ -24,7 +37,11 @@ end # Verifies that a >= 0 (or more precisely 0 <= a < RANGE_CHECK_BOUND). func assert_nn{range_check_ptr}(a): - %{ assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.' 
%} + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.' + %} a = [range_check_ptr] let range_check_ptr = range_check_ptr + 1 return () @@ -58,10 +75,9 @@ func assert_in_range{range_check_ptr}(value, lower, upper): return () end -# Asserts that a <= b. +# Asserts that a <= b. More specifically, asserts that b - a is in the range [0, 2**250). # -# Assumptions: -# a and b are in the range [0, 2**250). +# Prover assumptions: # PRIME - 2**250 > 2**(250 - 128) + 1 * RC_BOUND. func assert_le_250_bit{range_check_ptr}(a, b): let low = [range_check_ptr] @@ -69,20 +85,23 @@ func assert_le_250_bit{range_check_ptr}(a, b): let range_check_ptr = range_check_ptr + 2 const UPPER_BOUND = %[2**(250)%] const HIGH_PART_SHIFT = %[2**250 // 2**128 %] + tempvar diff = b - a %{ + from starkware.cairo.common.math_utils import as_int + # Soundness checks. assert range_check_builtin.bound == 2**128 assert ids.UPPER_BOUND == ids.HIGH_PART_SHIFT * range_check_builtin.bound - assert ids.a < ids.UPPER_BOUND, f'a={ids.a} is outside of the valid range.' - assert ids.b < ids.UPPER_BOUND, f'b={ids.b} is outside of the valid range.' - assert PRIME - ids.UPPER_BOUND > (ids.HIGH_PART_SHIFT + 1) * range_check_builtin.bound # Correctness check. - assert ids.a <= ids.b, f'a={ids.a} > b={ids.b}.' - %} + diff = as_int(ids.diff, PRIME) + values_msg = f'(a={as_int(ids.a, PRIME)}, b={as_int(ids.b, PRIME)}).' + assert diff < ids.UPPER_BOUND, f'(b - a)={diff} is outside of the valid range. {values_msg}' + assert PRIME - ids.UPPER_BOUND > (ids.HIGH_PART_SHIFT + 1) * range_check_builtin.bound - tempvar diff = b - a - %{ + assert diff >= 0, f'(b - a)={diff} < 0. {values_msg}' + + # Calculation for the assertion. 
ids.high = ids.diff // ids.HIGH_PART_SHIFT ids.low = ids.diff % ids.HIGH_PART_SHIFT %} @@ -110,7 +129,9 @@ func split_felt{range_check_ptr}(value) -> (high, low): let range_check_ptr = range_check_ptr + 2 %{ + from starkware.cairo.common.math_utils import assert_integer assert PRIME < 2**256 + assert_integer(ids.value) ids.low = ids.value & ((1 << 128) - 1) ids.high = ids.value >> 128 %} @@ -128,6 +149,9 @@ end # See split_felt() for more details. func assert_le_felt{range_check_ptr}(a, b): %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert_integer(ids.b) assert (ids.a % PRIME) <= (ids.b % PRIME), \ f'a = {ids.a % PRIME} is not less than or equal to b = {ids.b % PRIME}.' %} @@ -147,6 +171,9 @@ end # that of b. func assert_lt_felt{range_check_ptr}(a, b): %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert_integer(ids.b) assert (ids.a % PRIME) < (ids.b % PRIME), \ f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.' %} @@ -219,6 +246,8 @@ func unsigned_div_rem{range_check_ptr}(value, div) -> (q, r): let q = [range_check_ptr + 1] let range_check_ptr = range_check_ptr + 2 %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.div) assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ f'div={hex(ids.div)} is out of the valid range.' ids.q, ids.r = divmod(ids.value, ids.div) @@ -245,16 +274,17 @@ func signed_div_rem{range_check_ptr}(value, div, bound) -> (q, r): let biased_q = [range_check_ptr + 1] # == q + bound. let range_check_ptr = range_check_ptr + 2 %{ - def as_int(val): - return val if val < PRIME // 2 else val - PRIME + from starkware.cairo.common.math_utils import as_int, assert_integer + assert_integer(ids.div) assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ f'div={hex(ids.div)} is out of the valid range.' 
+ assert_integer(ids.bound) assert ids.bound <= range_check_builtin.bound // 2, \ f'bound={hex(ids.bound)} is out of the valid range.' - int_value = as_int(ids.value) + int_value = as_int(ids.value, PRIME) q, ids.r = divmod(int_value, ids.div) assert -ids.bound <= q < ids.bound, \ diff --git a/src/starkware/cairo/common/math_cmp.cairo b/src/starkware/cairo/common/math_cmp.cairo new file mode 100644 index 00000000..a6dae692 --- /dev/null +++ b/src/starkware/cairo/common/math_cmp.cairo @@ -0,0 +1,76 @@ +from starkware.cairo.common.math import assert_le_felt, assert_lt_felt + +const RC_BOUND = %[ 2**128 %] + +# Returns 1 if value != 0. Returns 0 otherwise. +func is_not_zero(value) -> (res): + if value == 0: + return (res=0) + end + + return (res=1) +end + +# Returns 1 if a >= 0 (or more precisely 0 <= a < RANGE_CHECK_BOUND). +# Returns 0 otherwise. +func is_nn{range_check_ptr}(a) -> (res): + %{ memory[ap] = 0 if 0 <= (ids.a % PRIME) < range_check_builtin.bound else 1 %} + jmp out_of_range if [ap] != 0; ap++ + [range_check_ptr] = a + let range_check_ptr = range_check_ptr + 1 + return (res=1) + + out_of_range: + %{ memory[ap] = 0 if 0 <= ((-ids.a - 1) % PRIME) < range_check_builtin.bound else 1 %} + jmp need_felt_comparison if [ap] != 0; ap++ + assert [range_check_ptr] = (-a) - 1 + let range_check_ptr = range_check_ptr + 1 + return (res=0) + + need_felt_comparison: + assert_le_felt(RC_BOUND, a) + return (res=0) +end + +# Returns 1 if a <= b (or more precisely 0 <= b - a < RANGE_CHECK_BOUND). +# Returns 0 otherwise. +func is_le{range_check_ptr}(a, b) -> (res): + return is_nn(b - a) +end + +# Returns 1 if 0 <= a <= b < RANGE_CHECK_BOUND. +# Returns 0 otherwise. +func is_nn_le{range_check_ptr}(a, b) -> (res): + let (res) = is_nn(a) + if res == 0: + return (res=res) + end + return is_le(a, b) +end + +# Returns 1 if value is in the range [lower, upper). +# Returns 0 otherwise.
+# Assumptions: +# upper - lower <= RC_BOUND +func is_in_range{range_check_ptr}(value, lower, upper) -> (res): + let (res) = is_le(lower, value) + if res == 0: + return (res=res) + end + return is_le(value, upper - 1) +end + +# Checks if the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +# or equal to that of b. +# See split_felt() for more details. +# Returns 1 if true, 0 otherwise. +func is_le_felt{range_check_ptr}(a, b) -> (res): + %{ memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1 %} + jmp not_le if [ap] != 0; ap++ + assert_le_felt(a, b) + return (res=1) + + not_le: + assert_lt_felt(b, a) + return (res=0) +end diff --git a/src/starkware/cairo/common/math_utils.py b/src/starkware/cairo/common/math_utils.py index 449c1f44..d9dee17f 100644 --- a/src/starkware/cairo/common/math_utils.py +++ b/src/starkware/cairo/common/math_utils.py @@ -1,8 +1,16 @@ +def assert_integer(val): + """ + Asserts that the input is an integer (and not relocatable value). + """ + assert isinstance(val, int), f'Expected integer, found: {val}.' + + def as_int(val, prime): """ Returns the lift of the given field element, val, as an integer in the range (-prime/2, prime/2). 
""" + assert_integer(val) return val if val < prime // 2 else val - prime diff --git a/src/starkware/cairo/common/memcpy.cairo b/src/starkware/cairo/common/memcpy.cairo index 09487ee5..cbf73e7d 100644 --- a/src/starkware/cairo/common/memcpy.cairo +++ b/src/starkware/cairo/common/memcpy.cairo @@ -9,13 +9,11 @@ func memcpy(dst : felt*, src : felt*, len): return () end - let frame = cast(ap, LoopFrame*) %{ vm_enter_scope({'n': ids.len}) %} - frame.dst = dst; ap++ - frame.src = src; ap++ + tempvar frame = LoopFrame(dst=dst, src=src) loop: - let frame = cast(ap - LoopFrame.SIZE, LoopFrame*) + let frame = [cast(ap - LoopFrame.SIZE, LoopFrame*)] assert [frame.dst] = [frame.src] let continue_copying = [ap] diff --git a/src/starkware/cairo/common/registers.cairo b/src/starkware/cairo/common/registers.cairo index be2758a1..302a94d7 100644 --- a/src/starkware/cairo/common/registers.cairo +++ b/src/starkware/cairo/common/registers.cairo @@ -10,15 +10,9 @@ end func get_ap() -> (ap_val): # Once get_ap() is invoked, fp points to ap + 2 (since the call instruction placed the old fp # and pc in memory, advancing ap accordingly). - # Calling dummy_func places fp and pc at [fp], [fp + 1] (respectively), and advances ap by 2. - # Hence, going two cells above we get [fp] = ap + 2, and by subtracting 2 we get the desired ap - # value. - call dummy_func - return (ap_val=[ap - 2] - 2) -end - -func dummy_func(): - return () + # Hence, the desired ap value is fp - 2. 
+ let (fp_val, pc_val) = get_fp_and_pc() + return (ap_val=fp_val - 2) end # Takes the value of a label (relative to program base) and returns the actual runtime address of diff --git a/src/starkware/cairo/common/set.cairo b/src/starkware/cairo/common/set.cairo new file mode 100644 index 00000000..49b4883a --- /dev/null +++ b/src/starkware/cairo/common/set.cairo @@ -0,0 +1,58 @@ +from starkware.cairo.common.math import assert_nn_le +from starkware.cairo.common.memcpy import memcpy + +const SET_ADD_RANGE_CHECK_USAGE_ON_DUPLICATE = 2 +const SET_ADD_RANGE_CHECK_USAGE_ON_NO_DUPLICATE = 0 + +# Given an array of elements and an element, does one of two things: +# 1. Adds the element to the array. +# 2. Verifies that the element is in the array (all of the fields of the element are equal to all of +# the fields of one of the array's elements). +# +# Note that this function does not ensure that the elements of the resulting array are distinct +# (from a soundness perspective). In other words, from the verifier's perspective, an element may be +# added even if it already existed. (On the other hand, from the prover's perspective an element won't +# be added if it already exists). +# This function is usually used in order to avoid long arrays where it doesn't matter if an element +# exists more than once in the array. +# +# Arguments: +# set_ptr - pointer to an array. +# elm_size - size of an element in the array. +# elm_ptr - pointer to an element (of size elm_size) to add to the set. +# +# Implicit arguments: +# range_check_ptr - range check builtin pointer. +# set_end_ptr - pointer to the end of the array. +# +# Assumptions: +# elm_size != 0.
+func set_add{range_check_ptr, set_end_ptr : felt*}(set_ptr : felt*, elm_size, elm_ptr : felt*): + alloc_locals + local is_elm_in_set + local index + %{ + assert ids.elm_size > 0 + assert ids.set_ptr <= ids.set_end_ptr + elm_list = memory.get_range(ids.elm_ptr, ids.elm_size) + for i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size): + if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list: + ids.index = i // ids.elm_size + ids.is_elm_in_set = 1 + break + else: + ids.is_elm_in_set = 0 + %} + if is_elm_in_set != 0: + local located_elm_ptr : felt* = set_ptr + elm_size * index + # Using memcpy for equality assertion. + memcpy(dst=located_elm_ptr, src=elm_ptr, len=elm_size) + tempvar n_elms = (cast(set_end_ptr, felt) - cast(set_ptr, felt)) / elm_size + assert_nn_le(index, n_elms - 1) + return () + else: + memcpy(dst=set_end_ptr, src=elm_ptr, len=elm_size) + let set_end_ptr : felt* = set_end_ptr + elm_size + return () + end +end diff --git a/src/starkware/cairo/common/small_merkle_tree_test.py b/src/starkware/cairo/common/small_merkle_tree_test.py deleted file mode 100644 index ac79dae6..00000000 --- a/src/starkware/cairo/common/small_merkle_tree_test.py +++ /dev/null @@ -1,50 +0,0 @@ -import os - -from starkware.cairo.common.dict import DictManager -from starkware.cairo.common.small_merkle_tree import MerkleTree -from starkware.cairo.common.test_utils import CairoFunctionRunner -from starkware.cairo.lang.builtins.hash.hash_builtin_runner import CELLS_PER_HASH -from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files -from starkware.native_crypto.native_crypto import pedersen_hash - -CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'small_merkle_tree.cairo') -PRIME = 2**251 + 17 * 2**192 + 1 -MERKLE_HEIGHT = 2 - - -def test_cairo_merkle_multi_update(): - program = compile_cairo_files([CAIRO_FILE], prime=PRIME, debug_info=True) - runner = CairoFunctionRunner(program) - - dict_manager = DictManager() - squashed_dict_start = 
dict_manager.new_dict( - segments=runner.segments, initial_dict={1: 10, 2: 20, 3: 30}) - - # Change the value at 1 from 10 to 11 and at 3 from 30 to 31. - squashed_dict = [1, 10, 11, 3, 30, 31] - squashed_dict_end = runner.segments.write_arg(ptr=squashed_dict_start, arg=squashed_dict) - dict_tracker = dict_manager.get_tracker(squashed_dict_start) - dict_tracker.current_ptr = squashed_dict_end - dict_tracker.data[1] = 11 - dict_tracker.data[3] = 31 - - runner.run( - 'small_merkle_tree', runner.hash_builtin.base, squashed_dict_start, squashed_dict_end, - MERKLE_HEIGHT, hint_locals=dict(__dict_manager=dict_manager)) - hash_ptr, prev_root, new_root = runner.get_return_values(3) - N_MERKLE_TREES = 2 - N_HASHES_PER_TREE = 3 - assert hash_ptr == \ - runner.hash_builtin.base + N_MERKLE_TREES * N_HASHES_PER_TREE * CELLS_PER_HASH - assert prev_root == pedersen_hash(pedersen_hash(0, 10), pedersen_hash(20, 30)) - assert new_root == pedersen_hash(pedersen_hash(0, 11), pedersen_hash(20, 31)) - - -def test_merkle_tree(): - tree = MerkleTree(tree_height=2, default_leaf=10) - expected_hash = pedersen_hash(pedersen_hash(10, 10), pedersen_hash(10, 10)) - assert tree.compute_merkle_root([]) == expected_hash - # Change leaf 1 to 7. 
- expected_hash = pedersen_hash(pedersen_hash(10, 7), pedersen_hash(10, 10)) - assert tree.compute_merkle_root([(1, 7)]) == expected_hash - assert tree.compute_merkle_root([]) == expected_hash diff --git a/src/starkware/cairo/common/squash_dict.cairo b/src/starkware/cairo/common/squash_dict.cairo index 78916749..2f055c7c 100644 --- a/src/starkware/cairo/common/squash_dict.cairo +++ b/src/starkware/cairo/common/squash_dict.cairo @@ -38,12 +38,19 @@ func squash_dict{range_check_ptr}( ap += 2 tempvar n_accesses = ptr_diff / DictAccess.SIZE %{ - assert ids.ptr_diff % ids.DictAccess.SIZE == 0, \ + dict_access_size = ids.DictAccess.SIZE + address = ids.dict_accesses.address_ + assert ids.ptr_diff % dict_access_size == 0, \ 'Accesses array size must be divisible by DictAccess.SIZE' + n_accesses = ids.n_accesses + if '__squash_dict_max_size' in globals(): + assert n_accesses <= __squash_dict_max_size, \ + f'squash_dict() can only be used with n_accesses<={__squash_dict_max_size}. ' \ + f'Got: n_accesses={n_accesses}.' # A map from key to the list of indices accessing it. access_indices = {} - for i in range(ids.n_accesses): - key = memory[ids.dict_accesses.address_ + ids.DictAccess.SIZE * i] + for i in range(n_accesses): + key = memory[address + dict_access_size * i] access_indices.setdefault(key, []).append(i) # Descending list of keys. keys = sorted(access_indices.keys(), reverse=True) @@ -139,7 +146,7 @@ func squash_dict_inner( local first_value = first_access.prev_value assert first_value = dict_diff.prev_value - # Skip loop non-deterministically if necessary. + # Skip loop nondeterministically if necessary. 
local should_skip_loop %{ ids.should_skip_loop = 0 if current_access_indices else 1 %} jmp skip_loop if should_skip_loop != 0 diff --git a/src/starkware/cairo/common/structs.py b/src/starkware/cairo/common/structs.py new file mode 100644 index 00000000..504adba9 --- /dev/null +++ b/src/starkware/cairo/common/structs.py @@ -0,0 +1,115 @@ +from collections import namedtuple +from typing import List, MutableMapping, Optional + +from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition +from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.python.utils import WriteOnceDict + + +class CairoStructFactory: + def __init__( + self, identifiers: IdentifierManager, additional_imports: Optional[List[str]] = None): + """ + Creates a CairoStructFactory that converts Cairo structs to python namedtuples. + + identifiers - an identifier manager holding the structs. + additional_imports - An optional list of fully qualified names of structs to preload. + Useful for importing absolute paths, rather than relative. + """ + self.identifiers = identifiers + + self.resolved_identifiers: MutableMapping[ScopedName, ScopedName] = WriteOnceDict() + if additional_imports is not None: + for identifier_path in additional_imports: + scope_name = ScopedName.from_string(identifier_path) + # Call get_struct_definition to make sure scope_name is a struct. 
+ get_struct_definition( + struct_name=scope_name, + identifier_manager=identifiers) + self.resolved_identifiers[scope_name[-1:]] = scope_name + + @classmethod + def from_program(cls, program: Program, additional_imports: Optional[List[str]] = None): + return cls(identifiers=program.identifiers, additional_imports=additional_imports) + + def _get_full_name(self, name: ScopedName): + full_name = self.resolved_identifiers.get(name) + if full_name is not None: + return full_name + + return self.identifiers.search( + accessible_scopes=[ScopedName.from_string('__main__'), ScopedName()], + name=name).get_canonical_name() + + def build_struct(self, name: ScopedName): + """ + Builds and returns namedtuple from a Cairo struct. + """ + full_name = self._get_full_name(name) + members = get_struct_definition(full_name, self.identifiers).members + return namedtuple(full_name.path[-1], list(members.keys())) + + def get_struct_size(self, name: ScopedName) -> int: + """ + Returns the size of the given struct. + """ + full_name = self._get_full_name(name) + return get_struct_definition(full_name, self.identifiers).size + + def build_func_args(self, func: ScopedName): + """ + Builds a namedtuple that contains both the explicit and the implicit arguments of 'func'. + """ + full_name = self._get_full_name(func) + + implict_args = get_struct_definition( + full_name + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, + self.identifiers).members + args = get_struct_definition( + full_name + CodeElementFunction.ARGUMENT_SCOPE, self.identifiers).members + return namedtuple(f'{func[-1:]}_full_args', list({**implict_args, **args})) + + @property + def structs(self): + """ + Dynamic namespace of all available structs. For example, to get the namedtuple of + a.b.MyStruct, use cairo_struct_factory.structs.a.b.MyStruct. + """ + return CairoStructProxy(self, ScopedName()) + + +class CairoStructProxy: + """ + Helper class for CairoStructFactory. See CairoStructFactory.structs.
+ """ + + def __init__(self, factory: CairoStructFactory, path: ScopedName): + self.factory = factory + self.path = path + + def __getattr__(self, name: str) -> 'CairoStructProxy': + return CairoStructProxy(self.factory, self.path + name) + + def build(self): + return self.factory.build_struct(self.path) + + def __call__(self, *args, **kwargs): + return self.build()(*args, **kwargs) + + @property + def size(self): + return self.factory.get_struct_size(self.path) + + def from_ptr(self, runner, addr): + """ + Interprets addr as a pointer to a struct of type path and creates the corresponding + namedtuple instance. + """ + named_tuple = self.build() + + return named_tuple(**{ + name: runner.vm_memory[addr + index] + for index, name in enumerate(named_tuple._fields)}) diff --git a/src/starkware/cairo/common/uint256.cairo b/src/starkware/cairo/common/uint256.cairo new file mode 100644 index 00000000..5eec5b30 --- /dev/null +++ b/src/starkware/cairo/common/uint256.cairo @@ -0,0 +1,83 @@ +from starkware.cairo.common.math import assert_nn_le, assert_not_zero +from starkware.cairo.common.math_cmp import is_le + +# Represents an integer in the range [0, 2^256). +struct Uint256: + # The low 128 bits of the value. + member low : felt + # The high 128 bits of the value. + member high : felt +end + +const SHIFT = %[2 ** 128%] +const HALF_SHIFT = %[2 ** 64%] + +# Verifies that the given integer is valid. +func uint256_check{range_check_ptr}(a : Uint256): + [range_check_ptr] = a.low + [range_check_ptr + 1] = a.high + let range_check_ptr = range_check_ptr + 2 + return () +end + +# Arithmetics. + +# Adds two integers. Returns the result as a 256-bit integer and the (1-bit) carry. 
+func uint256_add{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256, carry : felt): + alloc_locals + local res : Uint256 + local carry_low : felt + local carry_high : felt + %{ + sum_low = ids.a.low + ids.b.low + ids.carry_low = 1 if sum_low >= ids.SHIFT else 0 + sum_high = ids.a.high + ids.b.high + ids.carry_low + ids.carry_high = 1 if sum_high >= ids.SHIFT else 0 + %} + + assert carry_low * carry_low = carry_low + assert carry_high * carry_high = carry_high + + assert res.low = a.low + b.low - carry_low * SHIFT + assert res.high = a.high + b.high + carry_low - carry_high * SHIFT + uint256_check(res) + + return (res, carry_high) +end + +# Splits a field element in the range [0, 2^192) to its low 64-bit and high 128-bit parts. +func split_64{range_check_ptr}(a : felt) -> (low : felt, high : felt): + alloc_locals + local low : felt + local high : felt + + %{ + ids.low = ids.a & ((1<<64) - 1) + ids.high = ids.a >> 64 + %} + assert_nn_le(low, HALF_SHIFT) + [range_check_ptr] = high + let range_check_ptr = range_check_ptr + 1 + return (low, high) +end + +# Multiplies two integers. Returns the result as two 256-bit integers (low and high parts). 
+func uint256_mul{range_check_ptr}(a : Uint256, b : Uint256) -> (low : Uint256, high : Uint256): + alloc_locals + let (a0, a1) = split_64(a.low) + let (a2, a3) = split_64(a.high) + let (b0, b1) = split_64(b.low) + let (b2, b3) = split_64(b.high) + + let (res0, carry) = split_64(a0 * b0) + let (res1, carry) = split_64(a1 * b0 + a0 * b1 + carry) + let (res2, carry) = split_64(a2 * b0 + a1 * b1 + a0 * b2 + carry) + let (res3, carry) = split_64(a3 * b0 + a2 * b1 + a1 * b2 + a0 * b3 + carry) + let (res4, carry) = split_64(a3 * b1 + a2 * b2 + a1 * b3 + carry) + let (res5, carry) = split_64(a3 * b2 + a2 * b3 + carry) + let (res6, carry) = split_64(a3 * b3 + carry) + + return ( + low=Uint256(low=res0 + HALF_SHIFT * res1, high=res2 + HALF_SHIFT * res3), + high=Uint256(low=res4 + HALF_SHIFT * res5, high=res6 + HALF_SHIFT * carry)) +end diff --git a/src/starkware/cairo/lang/CMakeLists.txt b/src/starkware/cairo/lang/CMakeLists.txt index d49a6a92..c5a76d78 100644 --- a/src/starkware/cairo/lang/CMakeLists.txt +++ b/src/starkware/cairo/lang/CMakeLists.txt @@ -4,57 +4,4 @@ add_subdirectory(scripts) add_subdirectory(tracer) add_subdirectory(vm) -python_lib(cairo_version_lib - PREFIX starkware/cairo/lang - - FILES - VERSION - version.py -) - -if (NOT DEFINED CAIRO_PYTHON_INTERPRETER) - set(CAIRO_PYTHON_INTERPRETER python3.7) -endif() - -python_venv(cairo_lang_venv - PYTHON ${CAIRO_PYTHON_INTERPRETER} - LIBS - cairo_bootloader_generate_fact_lib - cairo_common_lib - cairo_compile_lib - cairo_hash_program_lib - cairo_run_lib - cairo_script_lib - ${CAIRO_LANG_VENV_ADDITIONAL_LIBS} -) - -python_venv(cairo_lang_package_venv - PYTHON python3.7 - LIBS - cairo_bootloader_generate_fact_lib - cairo_common_lib - cairo_compile_lib - cairo_hash_program_lib - cairo_run_lib - cairo_script_lib - sharp_client_lib - sharp_client_config_lib -) - -python_lib(cairo_instances_lib - PREFIX starkware/cairo/lang - - FILES - instances.py - ${CAIRO_INSTANCES_LIB_ADDITIONAL_FILES} - - LIBS - 
cairo_run_builtins_lib -) - -python_lib(cairo_constants_lib - PREFIX starkware/cairo/lang - - FILES - cairo_constants.py -) +include(lang.cmake) diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 6e8bf73a..0ea3a944 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.1.0 +0.2.0 diff --git a/src/starkware/cairo/lang/builtins/CMakeLists.txt b/src/starkware/cairo/lang/builtins/CMakeLists.txt index f59e9224..68b0df4e 100644 --- a/src/starkware/cairo/lang/builtins/CMakeLists.txt +++ b/src/starkware/cairo/lang/builtins/CMakeLists.txt @@ -2,8 +2,6 @@ python_lib(cairo_run_builtins_lib PREFIX starkware/cairo/lang/builtins FILES - checkpoints/instance_def.py - checkpoints/checkpoints_builtin_runner.py hash/hash_builtin_runner.py hash/instance_def.py range_check/instance_def.py @@ -12,7 +10,7 @@ python_lib(cairo_run_builtins_lib signature/signature_builtin_runner.py LIBS - cairo_relocatable + cairo_relocatable_lib cairo_vm_lib starkware_python_utils_lib ) diff --git a/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py b/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py deleted file mode 100644 index aaff0a47..00000000 --- a/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import List - -from starkware.cairo.lang.builtins.checkpoints.instance_def import CELLS_PER_SAMPLE -from starkware.cairo.lang.vm.builtin_runner import SimpleBuiltinRunner -from starkware.python.math_utils import safe_div - - -class CheckpointsBuiltinRunner(SimpleBuiltinRunner): - def __init__(self, name: str, included: bool, sample_ratio: int): - self.sample_ratio = sample_ratio - self.samples: List = [] - super().__init__(name, included, sample_ratio, CELLS_PER_SAMPLE) - - def finalize_segments(self, runner): - memory = runner.vm.run_context.memory - memory[self.stop_ptr] = 0 - memory[self.stop_ptr + 1] = 
0 - super().finalize_segments(runner) - - def get_used_cells_and_allocated_size(self, runner): - size = self.get_used_cells(runner) - return size, size - - def sample(self, step, pc, fp): - self.samples.append((step, pc, fp)) - - def relocate(self, relocate_value): - self.samples = [tuple(map(relocate_value, sample)) for sample in self.samples] - - def air_private_input(self, runner): - return {self.name: [ - { - 'index': safe_div(step, self.sample_ratio), - 'pc': hex(pc), - 'fp': hex(fp) - } - for step, pc, fp in self.samples]} diff --git a/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py b/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py deleted file mode 100644 index eac4489f..00000000 --- a/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py +++ /dev/null @@ -1,11 +0,0 @@ -import dataclasses - -# Each sample consists of 2 cells (required pc and required fp). -CELLS_PER_SAMPLE = 2 - - -@dataclasses.dataclass -class CheckpointsInstanceDef: - # Defines the ratio between the number of steps to the number of samples. - # For every sample_ratio steps, we have one sample. 
- sample_ratio: int diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt index 13619b8f..b36aebe4 100644 --- a/src/starkware/cairo/lang/compiler/CMakeLists.txt +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -10,6 +10,7 @@ python_lib(cairo_compile_lib ast/cairo_types.py ast/code_elements.py ast/expr.py + ast/expr_func_call.py ast/formatting_utils.py ast/instructions.py ast/module.py @@ -43,11 +44,14 @@ python_lib(cairo_compile_lib parser_transformer.py parser.py preprocessor/compound_expressions.py + preprocessor/default_pass_manager.py preprocessor/dependency_graph.py preprocessor/flow.py preprocessor/identifier_aware_visitor.py preprocessor/identifier_collector.py preprocessor/local_variables.py + preprocessor/pass_manager.py + preprocessor/preprocess_codes.py preprocessor/preprocessor_error.py preprocessor/preprocessor_utils.py preprocessor/preprocessor.py @@ -56,16 +60,19 @@ python_lib(cairo_compile_lib preprocessor/unique_labels.py program.py references.py + resolve_search_result.py scoped_name.py substitute_identifiers.py type_casts.py type_system_visitor.py + type_system.py LIBS cairo_constants_lib cairo_version_lib starkware_expression_string_lib starkware_python_utils_lib + starkware_utils_lib pip_marshmallow_dataclass pip_marshmallow_enum pip_marshmallow_oneofschema @@ -92,6 +99,7 @@ python_exe(cairo_format python_lib(cairo_compile_test_utils_lib PREFIX starkware/cairo/lang/compiler FILES + ast/ast_objects_test_utils.py preprocessor/preprocessor_test_utils.py test_utils.py @@ -110,6 +118,7 @@ full_python_test(cairo_compile_test ast_objects_test.py ast/formatting_utils_test.py cairo_compile_test.py + conftest.py encode_test.py error_handling_test.py expression_evaluator_test.py @@ -124,10 +133,13 @@ full_python_test(cairo_compile_test module_reader_test.py offset_reference_test.py parser_errors_test.py + parser_test_utils.py parser_test.py preprocessor/compound_expressions_test.py + 
preprocessor/conftest.py preprocessor/dependency_graph_test.py preprocessor/flow_test.py + preprocessor/identifier_aware_visitor_test.py preprocessor/identifier_collector_test.py preprocessor/local_variables_test.py preprocessor/preprocessor_test.py @@ -135,6 +147,7 @@ full_python_test(cairo_compile_test preprocessor/struct_collector_test.py preprocessor/unique_labels_test.py references_test.py + resolve_search_result_test.py scoped_name_test.py type_casts_test.py type_system_visitor_test.py @@ -142,5 +155,6 @@ full_python_test(cairo_compile_test LIBS cairo_compile_lib cairo_compile_test_utils_lib + starkware_python_utils_lib pip_pytest ) diff --git a/src/starkware/cairo/lang/compiler/assembler_test.py b/src/starkware/cairo/lang/compiler/assembler_test.py index c8ed75b0..d0dc430e 100644 --- a/src/starkware/cairo/lang/compiler/assembler_test.py +++ b/src/starkware/cairo/lang/compiler/assembler_test.py @@ -20,7 +20,7 @@ def test_main_scope(): identifiers=identifiers, reference_manager=reference_manager) # Check accessible identifiers. - assert program.get_identifier('b', ConstDefinition) + assert program.get_identifier('b', ConstDefinition).value == 1 # Ensure inaccessible identifiers. with pytest.raises(MissingIdentifierError, match="Unknown identifier 'a'."): @@ -32,6 +32,10 @@ def test_main_scope(): with pytest.raises(MissingIdentifierError, match="Unknown identifier 'y'."): program.get_identifier('y', ConstDefinition) + # Full name lookup. 
+ assert program.get_identifier('a.b', ConstDefinition, full_name_lookup=True).value == 1 + assert program.get_identifier('x.y.z', ConstDefinition, full_name_lookup=True).value == 2 + def test_program_start_property(): identifiers = IdentifierManager.from_dict({ diff --git a/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py b/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py new file mode 100644 index 00000000..a11d9815 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py @@ -0,0 +1,24 @@ +from starkware.cairo.lang.compiler.ast.expr import ( + ExprAddressOf, ExprDeref, ExprDot, ExprNeg, ExprOperator, ExprParentheses, ExprSubscript) + + +def remove_parentheses(expr): + """ + Removes the parentheses (ExprParentheses) from an arithmetic expression. + """ + if isinstance(expr, ExprParentheses): + return remove_parentheses(expr.val) + if isinstance(expr, ExprOperator): + return ExprOperator(a=remove_parentheses(expr.a), op=expr.op, b=remove_parentheses(expr.b)) + if isinstance(expr, ExprAddressOf): + return ExprAddressOf(expr=remove_parentheses(expr.expr)) + if isinstance(expr, ExprNeg): + return ExprNeg(val=remove_parentheses(expr.val)) + if isinstance(expr, ExprDeref): + return ExprDeref(addr=remove_parentheses(expr.addr)) + if isinstance(expr, ExprDot): + return ExprDot(expr=remove_parentheses(expr.expr), member=expr.member) + if isinstance(expr, ExprSubscript): + return ExprSubscript( + expr=remove_parentheses(expr.expr), offset=remove_parentheses(expr.offset)) + return expr diff --git a/src/starkware/cairo/lang/compiler/ast/code_elements.py b/src/starkware/cairo/lang/compiler/ast/code_elements.py index c2050f40..82a40986 100644 --- a/src/starkware/cairo/lang/compiler/ast/code_elements.py +++ b/src/starkware/cairo/lang/compiler/ast/code_elements.py @@ -1,11 +1,11 @@ import dataclasses from abc import abstractmethod -from typing import List, Optional, Sequence +from typing import Any, Dict, List, Optional, 
Sequence from starkware.cairo.lang.compiler.ast.aliased_identifier import AliasedIdentifier from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.bool_expr import BoolExpr -from starkware.cairo.lang.compiler.ast.expr import ArgListItem, Expression, ExprIdentifier +from starkware.cairo.lang.compiler.ast.expr import ExprAssignment, Expression, ExprIdentifier from starkware.cairo.lang.compiler.ast.formatting_utils import ( INDENTATION, LocationField, ParticleFormattingConfig, create_particle_sublist, particles_in_lines) @@ -150,7 +150,7 @@ class CodeElementReturn(CodeElement): Represents a statement of the form: return ([ident=]expr, ...). """ - exprs: List[ArgListItem] + exprs: List[ExprAssignment] location: Optional[Location] = LocationField def format(self, allowed_line_length): @@ -391,6 +391,8 @@ class CodeElementFunction(CodeElement): implicit_arguments: Optional[IdentifierList] returns: Optional[IdentifierList] code_block: CodeBlock + decorators: List[ExprIdentifier] + additional_attributes: Dict[str, Any] = dataclasses.field(default_factory=dict) ARGUMENT_SCOPE = ScopedName.from_string('Args') IMPLICIT_ARGUMENT_SCOPE = ScopedName.from_string('ImplicitArgs') @@ -426,12 +428,13 @@ def format(self, allowed_line_length): *implicit_args_particles, create_particle_sublist(self.arguments.get_particles(), '):')] + decorators = ''.join(f'@{decorator.format()}\n' for decorator in self.decorators) header = particles_in_lines( particles=particles, config=ParticleFormattingConfig( allowed_line_length=allowed_line_length, line_indent=INDENTATION * 2)) - return f'{header}\n{code}end' + return f'{decorators}{header}\n{code}end' def get_children(self) -> Sequence[Optional[AstNode]]: return [ @@ -505,6 +508,18 @@ def get_children(self) -> Sequence[Optional[AstNode]]: return [] +@dataclasses.dataclass +class LangDirective(Directive): + name: str + location: Optional[Location] = LocationField + + def format(self): + return 
f'%lang {self.name}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + @dataclasses.dataclass class CodeElementDirective(CodeElement): directive: Directive diff --git a/src/starkware/cairo/lang/compiler/ast/expr.py b/src/starkware/cairo/lang/compiler/ast/expr.py index 077b21a5..9db93916 100644 --- a/src/starkware/cairo/lang/compiler/ast/expr.py +++ b/src/starkware/cairo/lang/compiler/ast/expr.py @@ -76,20 +76,8 @@ def get_children(self) -> Sequence[Optional[AstNode]]: return [] -class ArgListItem(AstNode): - """ - Represents an item in function call or return statement. - """ - - location: Optional[Location] - - @abstractmethod - def format(self): - pass - - @dataclasses.dataclass -class ExprAssignment(ArgListItem): +class ExprAssignment(AstNode): """ A code element of the form [ident=]expr. The identifier is optional. """ @@ -112,7 +100,7 @@ class ArgList(AstNode): Represents a list of arguments (e.g., to a function call or a return statement). For example: 'a=1, b=2'. """ - args: List[ArgListItem] + args: List[ExprAssignment] notes: List[Notes] has_trailing_comma: bool location: Optional[Location] = LocationField @@ -229,7 +217,7 @@ def get_children(self) -> Sequence[Optional[AstNode]]: @dataclasses.dataclass class ExprDeref(Expression): """ - Represents an expression of the form "[expr]". + Represents an expression of the form "[addr]". """ addr: Expression notes: Notes = NotesField @@ -244,6 +232,45 @@ def get_children(self) -> Sequence[Optional[AstNode]]: return [self.addr] +@dataclasses.dataclass +class ExprSubscript(Expression): + """ + Represents an expression of the form "expr[offset]". + """ + expr: Expression + offset: Expression + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + self.notes.assert_no_comments() + notes = '' if self.notes.empty else '\n' + # If expr is not an atom, add parentheses. 
+ return ExpressionString.highest( + f'{self.expr.to_expr_str():HIGHEST}[{notes}{str(self.offset.to_expr_str())}]') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr, self.offset] + + +@dataclasses.dataclass +class ExprDot(Expression): + """ + Represents an expression of the form "expr.member". + """ + expr: Expression + member: ExprIdentifier + location: Optional[Location] = LocationField + + def to_expr_str(self): + # If expr is not an atom, add parentheses. + return ExpressionString.highest( + f'{self.expr.to_expr_str():HIGHEST}.{str(self.member.to_expr_str())}') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr, self.member] + + @dataclasses.dataclass class ExprCast(Expression): """ diff --git a/src/starkware/cairo/lang/compiler/ast/expr_func_call.py b/src/starkware/cairo/lang/compiler/ast/expr_func_call.py new file mode 100644 index 00000000..c353f283 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/expr_func_call.py @@ -0,0 +1,25 @@ +import dataclasses +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.expr import Expression +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.rvalue import RvalueFuncCall +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.python.expression_string import ExpressionString + + +@dataclasses.dataclass +class ExprFuncCall(Expression): + """ + Represents an expression of the form "()". For example, "foo(1, 2, z=3)". 
+ """ + + rvalue: RvalueFuncCall + location: Optional[Location] = LocationField + + def to_expr_str(self): + return ExpressionString.highest(self.rvalue.format_for_expr()) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.rvalue] diff --git a/src/starkware/cairo/lang/compiler/ast/rvalue.py b/src/starkware/cairo/lang/compiler/ast/rvalue.py index 922d1f3d..d5fe3967 100644 --- a/src/starkware/cairo/lang/compiler/ast/rvalue.py +++ b/src/starkware/cairo/lang/compiler/ast/rvalue.py @@ -132,5 +132,19 @@ def format(self, allowed_line_length): line_indent=INDENTATION, one_per_line=True)) + def format_for_expr(self) -> str: + """ + Formats the rvalue without automatic line breaking. + Should be used when the rvalue is part of an expression (where the line breaking mechanism + is not supported yet). + """ + res = self.func_ident.format() + + if self.implicit_arguments is not None: + res += '{' + self.implicit_arguments.format() + '}' + + res += '(' + self.arguments.format() + ')' + return res + def get_children(self) -> Sequence[Optional[AstNode]]: return [self.func_ident, self.arguments, self.implicit_arguments] diff --git a/src/starkware/cairo/lang/compiler/ast/visitor.py b/src/starkware/cairo/lang/compiler/ast/visitor.py index be2a9fa8..01661126 100644 --- a/src/starkware/cairo/lang/compiler/ast/visitor.py +++ b/src/starkware/cairo/lang/compiler/ast/visitor.py @@ -2,12 +2,18 @@ from typing import List, Optional from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeBlock, CodeElementFunction, CodeElementScoped, CodeElementWith, CommentedCodeElement) + CodeBlock, CodeElementDirective, CodeElementFunction, CodeElementScoped, CodeElementWith, + CommentedCodeElement, LangDirective) from starkware.cairo.lang.compiler.ast.module import CairoFile, CairoModule from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import LocationError from starkware.cairo.lang.compiler.scoped_name import 
ScopedName +class VisitorError(LocationError): + pass + + class Visitor: """ Base visitor class for visiting code elements in the Cairo AST. @@ -16,6 +22,7 @@ class Visitor: def __init__(self): self.accessible_scopes: List[ScopedName] = [] self.parents: List[Optional[AstNode]] = [] + self.file_lang: Optional[str] = None def visit(self, obj): """ @@ -37,10 +44,12 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): implicit_arguments=elm.implicit_arguments, returns=elm.returns, code_block=self.visit(elm.code_block), + decorators=elm.decorators, ) def visit_CairoModule(self, module: CairoModule): - with self.scoped(module.module_name, parent=module): + with self.scoped(module.module_name, parent=module), \ + self.with_file_lang(get_lang_from_file(module.cairo_file)): return CairoModule( cairo_file=CairoFile(code_block=self.visit(module.cairo_file.code_block)), module_name=module.module_name, @@ -66,7 +75,8 @@ def _visit_default(self, obj): """ Default behavior for visitor if 'obj' type isn't handled. By default, raise exception. """ - raise NotImplementedError(f'No handler found for type {type(obj).__name__}.') + raise NotImplementedError( + f'No handler found for type {type(obj).__name__} in {type(self).__name__}.') @contextmanager def scoped(self, new_scope: ScopedName, parent: Optional[AstNode]): @@ -81,6 +91,17 @@ def scoped(self, new_scope: ScopedName, parent: Optional[AstNode]): self.accessible_scopes.pop() self.parents.pop() + @contextmanager + def with_file_lang(self, lang: Optional[str]): + """ + Context manager for setting the file_lang member. 
+ """ + old_file_lang, self.file_lang = self.file_lang, lang + try: + yield + finally: + self.file_lang = old_file_lang + @property def current_scope(self) -> ScopedName: """ @@ -88,3 +109,22 @@ def current_scope(self) -> ScopedName: """ assert len(self.accessible_scopes) > 0 return self.accessible_scopes[-1] + + +def get_lang_from_file(cairo_file: CairoFile) -> Optional[str]: + """ + Returns the value of the %lang directive if it exists in the given file. + Returns None otherwise. + """ + lang = None + for commented_code_element in cairo_file.code_block.code_elements: + code_elm = commented_code_element.code_elm + if not isinstance(code_elm, CodeElementDirective): + continue + directive = code_elm.directive + if not isinstance(directive, LangDirective): + continue + if lang is not None: + raise VisitorError('Found two %lang directives', location=code_elm.location) + lang = directive.name + return lang diff --git a/src/starkware/cairo/lang/compiler/ast_objects_test.py b/src/starkware/cairo/lang/compiler/ast_objects_test.py index 955a51ae..2746713b 100644 --- a/src/starkware/cairo/lang/compiler/ast_objects_test.py +++ b/src/starkware/cairo/lang/compiler/ast_objects_test.py @@ -1,26 +1,11 @@ import pytest -from starkware.cairo.lang.compiler.ast.expr import ( - ExprAddressOf, ExprConst, ExprNeg, ExprOperator, ExprParentheses) +from starkware.cairo.lang.compiler.ast.ast_objects_test_utils import remove_parentheses +from starkware.cairo.lang.compiler.ast.expr import ExprConst, ExprNeg, ExprOperator from starkware.cairo.lang.compiler.ast.formatting_utils import FormattingError from starkware.cairo.lang.compiler.parser import parse_code_element, parse_expr, parse_file -def remove_parentheses(expr): - """ - Removes the parentheses (ExprParentheses) from an arithmetic expression. 
- """ - if isinstance(expr, ExprParentheses): - return expr.val - if isinstance(expr, ExprOperator): - return ExprOperator(a=remove_parentheses(expr.a), op=expr.op, b=remove_parentheses(expr.b)) - if isinstance(expr, ExprAddressOf): - return ExprAddressOf(expr=remove_parentheses(expr.expr)) - if isinstance(expr, ExprNeg): - return ExprNeg(val=remove_parentheses(expr.val)) - return expr - - def test_format_parentheses(): """ Tests that format() adds parentheses where required. @@ -34,16 +19,37 @@ def test_format_parentheses(): 'x - (a + b) - (c - d) - e * f' assert remove_parentheses(parse_expr('(a + b) + (c - d) + (e * f)')).format() == \ 'a + b + c - d + e * f' - assert remove_parentheses(parse_expr('-(a + b + c)')).format() == \ - '-(a + b + c)' - assert remove_parentheses(parse_expr('a + -b + c')).format() == \ - 'a + (-b) + c' - assert remove_parentheses(parse_expr('&(a + b)')).format() == \ - '&(a + b)' + assert remove_parentheses(parse_expr('-(a + b + c)')).format() == '-(a + b + c)' + assert remove_parentheses(parse_expr('a + -b + c')).format() == 'a + (-b) + c' + assert remove_parentheses(parse_expr('&(a + b)')).format() == '&(a + b)' + + # Test that parentheses are added to non-atomized Dot and Subscript expressions. 
+ assert remove_parentheses(parse_expr('(x * y).z')).format() == '(x * y).z' + assert remove_parentheses(parse_expr('(-x).y')).format() == '(-x).y' + assert remove_parentheses(parse_expr('(&x).y')).format() == '(&x).y' + assert remove_parentheses(parse_expr('(x * y)[z]')).format() == '(x * y)[z]' + assert remove_parentheses(parse_expr('(-x)[y]')).format() == '(-x)[y]' + assert remove_parentheses(parse_expr('(&x)[y]')).format() == '(&x)[y]' + + assert remove_parentheses(parse_expr('&(x.y)')).format() == '&x.y' + assert remove_parentheses(parse_expr('-(x.y)')).format() == '-x.y' + assert remove_parentheses(parse_expr('(x.y)*z')).format() == 'x.y * z' + assert remove_parentheses(parse_expr('x-(y.z)')).format() == 'x - y.z' + + assert remove_parentheses(parse_expr('([x].y).z')).format() == '[x].y.z' + assert remove_parentheses(parse_expr('&(x[y])')).format() == '&x[y]' + assert remove_parentheses(parse_expr('-(x[y])')).format() == '-x[y]' + assert remove_parentheses(parse_expr('(x[y])*z')).format() == 'x[y] * z' + assert remove_parentheses(parse_expr('x-(y[z])')).format() == 'x - y[z]' + assert remove_parentheses(parse_expr('(([x][y])[z])')).format() == '[x][y][z]' + assert remove_parentheses(parse_expr('x[(y+z)]')).format() == 'x[y + z]' + + assert remove_parentheses(parse_expr('[((x+y) + z)]')).format() == '[x + y + z]' # Test that parentheses are not added if they were already present. 
assert parse_expr('(a * (b + c))').format() == '(a * (b + c))' assert parse_expr('((a * ((b + c))))').format() == '((a * ((b + c))))' + assert parse_expr('(x + y)[z]').format() == '(x + y)[z]' def test_format_parentheses_notes(): @@ -110,8 +116,11 @@ def test_file_format(): before = """ ap+=[ fp ] +%lang starknet [ap + -1] = [fp] * 3 - const x=y + z + const x=y + f(a=g( + z) ,# test + b=0) member x:T.S let x= ap-y + z let y:a.b.c= x @@ -136,8 +145,11 @@ def test_file_format(): [fp] = [fp] * [fp]""" after = """\ ap += [fp] +%lang starknet [ap + (-1)] = [fp] * 3 -const x = y + z +const x = y + f(a=g( + z), # test + b=0) member x : T.S let x = ap - y + z let y : a.b.c = x diff --git a/src/starkware/cairo/lang/compiler/cairo.ebnf b/src/starkware/cairo/lang/compiler/cairo.ebnf index 8e45f218..a339f23b 100644 --- a/src/starkware/cairo/lang/compiler/cairo.ebnf +++ b/src/starkware/cairo/lang/compiler/cairo.ebnf @@ -9,6 +9,7 @@ _DBL_PLUS: "++" _DBL_EQ: "==" _NEQ: "!=" _ARROW: "->" +_AT: "@" // Types. type: "felt" -> type_felt @@ -21,6 +22,9 @@ expr_assignment: expr | identifier_def "=" expr ?arg_list_item: expr_assignment arg_list: ((notes arg_list_item notes ",")* notes arg_list_item (notes ",")?)? notes +?decorator: _AT identifier_def +decorator_list: (decorator _NEWLINE*)* + ?expr: sum ?sum: product | sum "+" notes product -> expr_add @@ -36,9 +40,12 @@ identifier_def: IDENTIFIER ?atom: INT -> atom_number | PYCONST -> atom_pyconst | reg -> atom_reg + | function_call -> atom_func_call | identifier | "(" notes expr ")" -> atom_parentheses | "[" notes expr "]" -> atom_deref + | atom "[" notes expr "]" -> atom_subscript + | atom "." identifier_def -> atom_dot | "cast" "(" notes expr "," type ")" -> atom_cast | "(" arg_list ")" -> atom_tuple !reg: "ap" -> reg_ap @@ -76,12 +83,13 @@ instruction: instruction_body -> instruction_noap function_call: identifier ("{" arg_list "}")? "(" arg_list ")" // Reference expressions. 
-?rvalue: expr -> rvalue_expr - | call_instruction -> rvalue_call_instruction - | function_call +?rvalue: call_instruction -> rvalue_call_instruction + | function_call + | expr -> rvalue_expr // Directives. directive: "%builtins" identifier+ -> directive_builtins + | "%lang" identifier -> directive_lang // Import statement. aliased_identifier: identifier_def ("as" identifier_def)? @@ -96,7 +104,7 @@ _returns: _ARROW _NEWLINE* "(" identifier_list ")" _arguments: "(" identifier_list ")" implicit_arguments: ("{" identifier_list "}")? _funcdecl: "func" identifier_def implicit_arguments _arguments _NEWLINE* _returns? ":" -_func: _funcdecl _NEWLINE code_block "end" +_func: decorator_list _funcdecl _NEWLINE code_block "end" _if: "if" bool_expr ":" _NEWLINE code_block ("else" ":" _NEWLINE code_block)? "end" !_some_namespace: "struct" | "namespace" diff --git a/src/starkware/cairo/lang/compiler/cairo_compile.py b/src/starkware/cairo/lang/compiler/cairo_compile.py index 7630d270..2ed6ca45 100644 --- a/src/starkware/cairo/lang/compiler/cairo_compile.py +++ b/src/starkware/cairo/lang/compiler/cairo_compile.py @@ -3,7 +3,7 @@ import os import sys import time -from typing import List, Sequence, Set, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.compiler.assembler import assemble @@ -12,16 +12,16 @@ from starkware.cairo.lang.compiler.identifier_manager import IdentifierError from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.module_reader import ModuleReader -from starkware.cairo.lang.compiler.preprocessor.preprocessor import Preprocessor, preprocess_codes +from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import default_pass_manager +from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManager +from 
starkware.cairo.lang.compiler.preprocessor.preprocess_codes import preprocess_codes +from starkware.cairo.lang.compiler.preprocessor.preprocessor import PreprocessedProgram from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.version import __version__ -def main(): - start_time = time.time() - - parser = argparse.ArgumentParser(description='A tool to compile Cairo code.') +def cairo_compile_add_common_args(parser: argparse.ArgumentParser): parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') parser.add_argument('files', metavar='file', type=str, nargs='+', help='File names') parser.add_argument( @@ -44,52 +44,71 @@ def main(): parser.add_argument( '--debug_info_with_source', action='store_true', help='Include debug information with a copy of the source code.') - parser.add_argument( - '--simple', action='store_true', - help='Compile the program without adding additional code. ' - 'In particular, program starts at the __start__ label, instead of the main() function.') parser.add_argument( '--cairo_dependencies', type=str, help='Output a list of the Cairo source files used during the compilation as a CMake file.') + parser.add_argument( + '--no_opt_unused_functions', dest='opt_unused_functions', action='store_false', + default=True, help='Disables unused function optimization.') - args = parser.parse_args() + +def cairo_compile_common( + args: argparse.Namespace, + pass_manager_factory: Callable[[argparse.Namespace, ModuleReader], PassManager], + assemble_func: Callable) -> PreprocessedProgram: + """ + Common code for CLI Cairo compilation. + + Arguments: + args - Parsed arguments. + pass_manager_factory - A pass manager factory. + assemble_func - a function that converts a preprocessed program to the final output, + the return value should be a Marshmallow dataclass. 
+ """ + + start_time = time.time() debug_info = args.debug_info or args.debug_info_with_source - source_files = set() - erred: bool = False try: codes = get_codes(args.files) - if not args.simple: + file_contents_for_debug_info = {} + if getattr(args, 'proof_mode', False): codes = add_start_code(codes) + file_contents_for_debug_info[START_FILE_NAME] = codes[0][0] out = args.output if args.output is not None else sys.stdout - cairo_path = list(filter( + cairo_path: List[str] = list(filter( None, args.cairo_path.split(':') + os.getenv(LIBS_DIR_ENVVAR, '').split(':'))) + module_reader = get_module_reader(cairo_path=cairo_path) + + pass_manager = pass_manager_factory(args, module_reader) + + preprocessed = preprocess_codes( + codes=codes, + pass_manager=pass_manager, + main_scope=MAIN_SCOPE) if args.preprocess: - module_reader = get_module_reader(cairo_path) - preprocessed = preprocess_codes(codes, args.prime, module_reader.read, MAIN_SCOPE) - source_files = module_reader.source_files print(preprocessed.format(with_locations=debug_info), end='', file=out) else: - program, source_files = compile_cairo_extended( - codes, args.prime, cairo_path, debug_info, simple=args.simple) if args.debug_info_with_source: - for source_file in source_files | set(args.files): - program.debug_info.file_contents[source_file] = open(source_file).read() - json.dump(Program.Schema().dump(program), out, indent=4, sort_keys=True) - # Print a new line at the end. 
- print(file=out) + for source_file in module_reader.source_files | set(args.files): + file_contents_for_debug_info[source_file] = open(source_file).read() - except LocationError as err: - print(err, file=sys.stderr) - erred = True + assembled_program = assemble_func( + preprocessed, main_scope=MAIN_SCOPE, add_debug_info=debug_info, + file_contents_for_debug_info=file_contents_for_debug_info) - if args.cairo_dependencies: - generate_cairo_dependencies_file( - args.cairo_dependencies, source_files | set(args.files), start_time) + json.dump( + assembled_program.Schema().dump(assembled_program), out, indent=4, sort_keys=True) + # Print a new line at the end. + print(file=out) - return 1 if erred else 0 + return preprocessed + finally: + if args.cairo_dependencies: + generate_cairo_dependencies_file( + args.cairo_dependencies, module_reader.source_files | set(args.files), start_time) def get_module_reader(cairo_path: List[str]) -> ModuleReader: @@ -117,39 +136,28 @@ def add_start_code(codes_with_filenames: List[Tuple[str, str]]) -> List[Tuple[st def compile_cairo_files( - files: List[str], prime: int, + files: List[str], prime: Optional[int] = None, cairo_path: List[str] = [], debug_info: bool = False, - preprocessor_cls: Type[Preprocessor] = Preprocessor) -> Program: + pass_manager: Optional[PassManager] = None, + main_scope: Optional[ScopedName] = None) -> Program: """ Compiles a list of files (provided by their names). Note that cairo_path is ignored when reading the input files, it is only used when importing modules. """ - return compile_cairo(get_codes(files), prime, cairo_path, debug_info, preprocessor_cls) - - -def compile_cairo( - code: Union[str, Sequence[Tuple[str, str]]], prime: int, - cairo_path: List[str] = [], debug_info: bool = False, - preprocessor_cls: Type[Preprocessor] = Preprocessor, - add_start: bool = False) -> Program: - """ - Compiles a single code represented by a string, or a list codes. 
- The codes in the list are joined with file names, used for indicative - compilation errors. - """ - return compile_cairo_extended( - code=code, prime=prime, cairo_path=cairo_path, debug_info=debug_info, - preprocessor_cls=preprocessor_cls, add_start=add_start)[0] + return compile_cairo( + code=get_codes(files), prime=prime, cairo_path=cairo_path, debug_info=debug_info, + pass_manager=pass_manager, main_scope=main_scope) -def compile_cairo_extended( - code: Union[str, Sequence[Tuple[str, str]]], prime: int, +def compile_cairo_ex( + code: Union[str, Sequence[Tuple[str, str]]], prime: Optional[int] = None, cairo_path: List[str] = [], debug_info: bool = False, - preprocessor_cls: Type[Preprocessor] = Preprocessor, - add_start: bool = False, simple: bool = False) -> Tuple[Program, Set[str]]: + pass_manager: Optional[PassManager] = None, + add_start: bool = False, main_scope: Optional[ScopedName] = None) -> \ + Tuple[Program, PreprocessedProgram]: """ - Same as compile_cairo(), except that imported Cairo modules are returned. + Same as compile_cairo, but returns the preprocessed program as well. """ file_contents_for_debug_info = {} @@ -165,18 +173,42 @@ def compile_cairo_extended( if START_FILE_NAME == codes_with_filenames[0][1]: file_contents_for_debug_info[START_FILE_NAME] = codes_with_filenames[0][0] - module_reader = get_module_reader(cairo_path) - preprocessed_program = preprocess_codes( - codes_with_filenames, prime, module_reader.read, MAIN_SCOPE, preprocessor_cls) + if pass_manager is None: + assert prime is not None, 'Exactly one of prime and pass_manager must be given.' + module_reader = get_module_reader(cairo_path) + pass_manager = default_pass_manager(prime=prime, read_module=module_reader.read) + else: + assert prime is None, 'Exactly one of prime and pass_manager must be given.' + assert len(cairo_path) == 0, 'cairo_path cannot be specified where pass_manager is used.' 
- program = assemble( - preprocessed_program, main_scope=MAIN_SCOPE, add_debug_info=debug_info, + if main_scope is None: + main_scope = MAIN_SCOPE + preprocessed_program = preprocess_codes( + codes=codes_with_filenames, + pass_manager=pass_manager, + main_scope=main_scope) + program = cairo_assemble_program( + preprocessed_program, main_scope=main_scope, add_debug_info=debug_info, file_contents_for_debug_info=file_contents_for_debug_info) - if not simple: - check_main_args(program) + return program, preprocessed_program - return program, module_reader.source_files + +def compile_cairo( + code: Union[str, Sequence[Tuple[str, str]]], prime: Optional[int] = None, + cairo_path: List[str] = [], debug_info: bool = False, + pass_manager: Optional[PassManager] = None, + add_start: bool = False, main_scope: Optional[ScopedName] = None) -> Program: + """ + Compiles a single code represented by a string, or a list codes. + The codes in the list are joined with file names, used for indicative + compilation errors. + Returns the program. 
+ """ + program, _ = compile_cairo_ex( + code=code, prime=prime, cairo_path=cairo_path, debug_info=debug_info, + pass_manager=pass_manager, add_start=add_start, main_scope=main_scope) + return program def check_main_args(program: Program): @@ -223,6 +255,7 @@ def get_start_code(): """ return """\ __start__: +ap += main.Args.SIZE + main.ImplicitArgs.SIZE call main __end__: @@ -253,5 +286,43 @@ def generate_cairo_dependencies_file(dependencies_path: str, files: Set[str], st os.utime(dependencies_path, (start_time, start_time)) +def cairo_assemble_program( + preprocessed_program: PreprocessedProgram, main_scope: ScopedName, + add_debug_info: bool, file_contents_for_debug_info: Dict[str, str]) -> Program: + program = assemble( + preprocessed_program, main_scope=MAIN_SCOPE, add_debug_info=add_debug_info, + file_contents_for_debug_info=file_contents_for_debug_info) + check_main_args(program) + return program + + +def main(): + parser = argparse.ArgumentParser(description='A tool to compile Cairo code.') + parser.add_argument( + '--proof_mode', action='store_true', default=False, + help='Add instructions to call main() at the beginning of the program. 
This should be used ' + 'if the program is proven directly (without the bootloader).') + parser.add_argument( + '--no_proof_mode', dest='proof_mode', action='store_false', + help='Disable proof mode (see --proof_mode).') + + def pass_manager_factory(args: argparse.Namespace, module_reader: ModuleReader) -> PassManager: + return default_pass_manager( + prime=args.prime, + read_module=module_reader.read, + opt_unused_functions=args.opt_unused_functions) + + try: + cairo_compile_add_common_args(parser) + args = parser.parse_args() + cairo_compile_common( + args=args, pass_manager_factory=pass_manager_factory, + assemble_func=cairo_assemble_program) + except LocationError as err: + print(err, file=sys.stderr) + return 1 + return 0 + + if __name__ == '__main__': sys.exit(main()) diff --git a/src/starkware/cairo/lang/compiler/conftest.py b/src/starkware/cairo/lang/compiler/conftest.py new file mode 100644 index 00000000..5e0ea869 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/conftest.py @@ -0,0 +1,6 @@ +import pytest + +# Instruct pytest to print full information (e.g., the values on both sides of the equality) +# about asserts that failed in the module below. +# Normally, pytest prints full information only for test files (according to their name). 
+pytest.register_assert_rewrite('starkware.cairo.lang.compiler.parser_test_utils') diff --git a/src/starkware/cairo/lang/compiler/const_expr_checker.py b/src/starkware/cairo/lang/compiler/const_expr_checker.py index 8babe63f..40148f03 100644 --- a/src/starkware/cairo/lang/compiler/const_expr_checker.py +++ b/src/starkware/cairo/lang/compiler/const_expr_checker.py @@ -1,11 +1,11 @@ from starkware.cairo.lang.compiler.ast.expr import ( - ExprAddressOf, ExprCast, ExprConst, ExprDeref, ExprFutureLabel, ExprNeg, ExprOperator, - ExprParentheses, ExprPyConst, ExprReg) + ExprConst, ExprDeref, ExprFutureLabel, ExprNeg, ExprOperator, ExprPyConst, ExprReg) class ConstExprChecker: """ A visitor class to check whether an expression contains only numeric and symbolic constants. + This visitor needs to visit only expressions returned by the type system visitor. """ def visit(self, obj): @@ -29,18 +29,9 @@ def visit_ExprOperator(self, expr: ExprOperator): def visit_ExprNeg(self, expr: ExprNeg): return self.visit(expr.val) - def visit_ExprParentheses(self, expr: ExprParentheses): - return self.visit(expr.val) - def visit_ExprDeref(self, expr: ExprDeref): return False - def visit_ExprAddressOf(self, expr: ExprAddressOf): - return False - - def visit_ExprCast(self, expr: ExprCast): - return self.visit(expr) - def is_const_expr(expr): """ diff --git a/src/starkware/cairo/lang/compiler/expression_evaluator.py b/src/starkware/cairo/lang/compiler/expression_evaluator.py index 35141fe4..6e2c9e7b 100644 --- a/src/starkware/cairo/lang/compiler/expression_evaluator.py +++ b/src/starkware/cairo/lang/compiler/expression_evaluator.py @@ -1,9 +1,10 @@ -from typing import Dict, Optional +from typing import MutableMapping, Optional from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer from starkware.cairo.lang.compiler.ast.expr import ExprConst, ExprDeref, Expression, ExprReg from starkware.cairo.lang.compiler.error_handling import LocationError from 
starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.type_system_visitor import simplify_type_system @@ -15,15 +16,18 @@ class ExpressionEvaluatorError(LocationError): class ExpressionEvaluator(ExpressionSimplifier): prime: int - def __init__(self, prime: int, ap: Optional[int], fp: int, memory: Dict[int, int]): + def __init__( + self, prime: int, ap: Optional[int], fp: int, memory: MutableMapping[int, int], + identifiers: Optional[IdentifierManager] = None): super().__init__(prime=prime) assert self.prime is not None self.ap = ap self.fp = fp self.memory = memory + self.identifiers = identifiers def eval(self, expr: Expression) -> int: - expr, expr_type = simplify_type_system(expr) + expr, expr_type = simplify_type_system(expr, identifiers=self.identifiers) assert isinstance(expr_type, (TypeFelt, TypePointer)), \ f"Unable to evaluate expression of type '{expr_type.format()}'." 
res = self.visit(expr) diff --git a/src/starkware/cairo/lang/compiler/expression_transformer.py b/src/starkware/cairo/lang/compiler/expression_transformer.py index 50bc6974..d15c9a92 100644 --- a/src/starkware/cairo/lang/compiler/expression_transformer.py +++ b/src/starkware/cairo/lang/compiler/expression_transformer.py @@ -1,8 +1,11 @@ from typing import Optional from starkware.cairo.lang.compiler.ast.expr import ( - ExprAddressOf, ExprCast, ExprConst, ExprDeref, Expression, ExprFutureLabel, ExprIdentifier, - ExprNeg, ExprOperator, ExprParentheses, ExprPyConst, ExprReg, ExprTuple) + ArgList, ExprAddressOf, ExprAssignment, ExprCast, ExprConst, ExprDeref, ExprDot, Expression, + ExprFutureLabel, ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, ExprPyConst, ExprReg, + ExprSubscript, ExprTuple) +from starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall +from starkware.cairo.lang.compiler.ast.rvalue import RvalueFuncCall from starkware.cairo.lang.compiler.error_handling import Location, LocationError @@ -59,6 +62,20 @@ def visit_ExprParentheses(self, expr: ExprParentheses): def visit_ExprDeref(self, expr: ExprDeref): return ExprDeref(addr=self.visit(expr.addr), location=self.location_modifier(expr.location)) + def visit_ExprSubscript(self, expr: ExprSubscript): + return ExprSubscript( + expr=self.visit(expr.expr), + offset=self.visit(expr.offset), + location=self.location_modifier(expr.location)) + + def visit_ExprDot(self, expr: ExprDot): + return ExprDot( + expr=self.visit(expr.expr), + # Avoid visiting 'member' with an overridden visit_ExprIdentifier, as it is not a + # proper identifier. 
+ member=ExpressionTransformer.visit_ExprIdentifier(self, expr.member), + location=self.location_modifier(expr.location)) + def visit_ExprAddressOf(self, expr: ExprAddressOf): inner_expr = self.visit(expr.expr) return ExprAddressOf( @@ -72,8 +89,36 @@ def visit_ExprCast(self, expr: ExprCast): cast_type=expr.cast_type, location=self.location_modifier(expr.location)) + def visit_ArgList(self, arg_list: ArgList): + return ArgList( + args=[ + ExprAssignment( + identifier=item.identifier, + expr=self.visit(item.expr), + location=self.location_modifier(item.location)) + for item in arg_list.args + ], + notes=arg_list.notes, + has_trailing_comma=arg_list.has_trailing_comma, + location=self.location_modifier(arg_list.location)) + def visit_ExprTuple(self, expr: ExprTuple): - raise ExpressionTransformerError('Tuples are not supported yet.', location=expr.location) + return ExprTuple( + members=self.visit_ArgList(expr.members), + location=self.location_modifier(expr.location)) + + def visit_RvalueFuncCall(self, rvalue: RvalueFuncCall): + return RvalueFuncCall( + func_ident=self.visit(rvalue.func_ident), + arguments=self.visit_ArgList(rvalue.arguments), + implicit_arguments=None if rvalue.implicit_arguments is None else self.visit_ArgList( + rvalue.implicit_arguments), + location=self.location_modifier(rvalue.location)) + + def visit_ExprFuncCall(self, expr: ExprFuncCall): + return ExprFuncCall( + rvalue=self.visit_RvalueFuncCall(expr.rvalue), + location=self.location_modifier(expr.location)) def location_modifier(self, location: Optional[Location]) -> Optional[Location]: """ diff --git a/src/starkware/cairo/lang/compiler/fields.py b/src/starkware/cairo/lang/compiler/fields.py index f57148bd..2ff61b27 100644 --- a/src/starkware/cairo/lang/compiler/fields.py +++ b/src/starkware/cairo/lang/compiler/fields.py @@ -2,7 +2,7 @@ from starkware.cairo.lang.compiler.ast.cairo_types import CairoType from starkware.cairo.lang.compiler.parser import parse_expr, parse_type -from 
starkware.cairo.lang.compiler.type_system_visitor import ( +from starkware.cairo.lang.compiler.type_system import ( is_type_resolved, mark_type_resolved, mark_types_in_expr_resolved) diff --git a/src/starkware/cairo/lang/compiler/identifier_definition.py b/src/starkware/cairo/lang/compiler/identifier_definition.py index a0882501..2ab4449b 100644 --- a/src/starkware/cairo/lang/compiler/identifier_definition.py +++ b/src/starkware/cairo/lang/compiler/identifier_definition.py @@ -104,6 +104,14 @@ class LabelDefinition(IdentifierDefinition): pc: int +@marshmallow_dataclass.dataclass +class FunctionDefinition(LabelDefinition): + TYPE: ClassVar[str] = 'function' + Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema + + decorators: List[str] + + @marshmallow_dataclass.dataclass class ReferenceDefinition(IdentifierDefinition): TYPE: ClassVar[str] = 'reference' @@ -144,6 +152,7 @@ class IdentifierDefinitionSchema(OneOfSchema): ConstDefinition.TYPE: ConstDefinition.Schema, MemberDefinition.TYPE: MemberDefinition.Schema, LabelDefinition.TYPE: LabelDefinition.Schema, + FunctionDefinition.TYPE: FunctionDefinition.Schema, ReferenceDefinition.TYPE: ReferenceDefinition.Schema, ScopeDefinition.TYPE: ScopeDefinition.Schema, StructDefinition.TYPE: StructDefinition.Schema, diff --git a/src/starkware/cairo/lang/compiler/identifier_manager.py b/src/starkware/cairo/lang/compiler/identifier_manager.py index d604f675..59ad8940 100644 --- a/src/starkware/cairo/lang/compiler/identifier_manager.py +++ b/src/starkware/cairo/lang/compiler/identifier_manager.py @@ -1,8 +1,8 @@ import dataclasses -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Set, Union from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, IdentifierDefinition) + AliasDefinition, FutureIdentifierDefinition, IdentifierDefinition) from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -234,6 +234,34 @@ def exclude(self, other: 
'IdentifierManager') -> 'IdentifierManager': if name not in other_as_dict }) + def prune(self, prefixes_to_prune: Set[ScopedName]): + """ + Removes identifiers that have one of the given prefixes. + """ + # Prune dict. + new_dict = {} + for name, value in self.dict.items(): + parent = name + while len(parent.path) > 0: + if parent in prefixes_to_prune: + break + parent = parent[:-1] + if parent in prefixes_to_prune: + assert isinstance(value, (IdentifierDefinition, FutureIdentifierDefinition)), \ + f"Attempted to prune identifier '{value}'" \ + f" of unprunable type '{type(value).__name__}'." + continue + new_dict[name] = value + self.dict = new_dict + + # Remove scopes. + for prefix in prefixes_to_prune: + assert len(prefix.path) > 0 + current = self.root + for element in prefix[:-1].path: + current = current.subscopes[element] + del current.subscopes[prefix.path[-1]] + class IdentifierScope: """ diff --git a/src/starkware/cairo/lang/compiler/identifier_manager_test.py b/src/starkware/cairo/lang/compiler/identifier_manager_test.py index bae36571..525b2e36 100644 --- a/src/starkware/cairo/lang/compiler/identifier_manager_test.py +++ b/src/starkware/cairo/lang/compiler/identifier_manager_test.py @@ -136,7 +136,7 @@ def test_identifier_manager_search(): ([''], 'x', 'x'), (['a', 'e', 'a.b.c'], 'b.z', 'a.b.z'), ]: - result = manager.search(list(map(scope, accessible_scopes)), name) + result = manager.search(list(map(scope, accessible_scopes)), scope(name)) assert result.canonical_name == scope(canonical_name) assert result.identifier_definition == identifier_dict[scope(canonical_name)] diff --git a/src/starkware/cairo/lang/compiler/identifier_utils.py b/src/starkware/cairo/lang/compiler/identifier_utils.py index dd44bb96..7900318c 100644 --- a/src/starkware/cairo/lang/compiler/identifier_utils.py +++ b/src/starkware/cairo/lang/compiler/identifier_utils.py @@ -1,11 +1,8 @@ from typing import Dict -from starkware.cairo.lang.compiler.constants import SIZE_CONSTANT -from 
starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, DefinitionError, IdentifierDefinition, ReferenceDefinition, StructDefinition) +from starkware.cairo.lang.compiler.identifier_definition import DefinitionError, StructDefinition from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierError, IdentifierManager, IdentifierSearchResult, MissingIdentifierError) -from starkware.cairo.lang.compiler.offset_reference import OffsetReferenceDefinition + IdentifierManager, MissingIdentifierError) from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -40,43 +37,3 @@ def get_struct_member_offsets( return { name: member_def.offset for name, member_def in struct_def.members.items() } - - -def resolve_search_result( - search_result: IdentifierSearchResult, - identifiers: IdentifierManager) -> IdentifierDefinition: - """ - Returns a fully parsed identifier definition for the given identifier search result. - If search_result contains a reference with non_parsed data, returns an instance of - OffsetReferenceDefinition. - """ - identifier_definition = search_result.identifier_definition - - if len(search_result.non_parsed) == 0: - return identifier_definition - - if isinstance(identifier_definition, StructDefinition): - if search_result.non_parsed == SIZE_CONSTANT: - return ConstDefinition(value=identifier_definition.size) - - member_def = identifier_definition.members.get(search_result.non_parsed.path[0]) - struct_name = identifier_definition.full_name - if member_def is None: - raise DefinitionError( - f"'{search_result.non_parsed}' is not a member of '{struct_name}'.") - - if len(search_result.non_parsed) > 1: - raise IdentifierError( - f"Unexpected '.' 
after '{struct_name + search_result.non_parsed.path[0]}' which is " - f'{member_def.TYPE}.') - - identifier_definition = member_def - elif isinstance(identifier_definition, ReferenceDefinition): - identifier_definition = OffsetReferenceDefinition( - parent=identifier_definition, - identifiers=identifiers, - member_path=search_result.non_parsed) - else: - search_result.assert_fully_parsed() - - return identifier_definition diff --git a/src/starkware/cairo/lang/compiler/identifier_utils_test.py b/src/starkware/cairo/lang/compiler/identifier_utils_test.py index 240b92ae..23d3ce8f 100644 --- a/src/starkware/cairo/lang/compiler/identifier_utils_test.py +++ b/src/starkware/cairo/lang/compiler/identifier_utils_test.py @@ -6,9 +6,8 @@ from starkware.cairo.lang.compiler.identifier_definition import ( ConstDefinition, DefinitionError, MemberDefinition, StructDefinition) from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierError, IdentifierManager, IdentifierSearchResult, MissingIdentifierError) -from starkware.cairo.lang.compiler.identifier_utils import ( - get_struct_definition, resolve_search_result) + IdentifierManager, MissingIdentifierError) +from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.scoped_name import ScopedName scope = ScopedName.from_string @@ -45,29 +44,3 @@ def test_get_struct_definition(): with pytest.raises( MissingIdentifierError, match=re.escape("Unknown identifier 'abc'.")): get_struct_definition(scope('abc'), manager) - - -def test_resolve_search_result(): - struct_def = StructDefinition( - full_name=scope('T'), - members={ - 'a': MemberDefinition(offset=0, cairo_type=TypeFelt()), - - 'b': MemberDefinition(offset=1, cairo_type=TypeFelt()), - }, - size=2, - ) - - identifier_dict = { - struct_def.full_name: struct_def, - } - - identifier = IdentifierManager.from_dict(identifier_dict) - - with pytest.raises(IdentifierError, match="Unexpected '.' 
after 'T.a' which is member"): - resolve_search_result( - search_result=IdentifierSearchResult( - identifier_definition=struct_def, - canonical_name=struct_def.full_name, - non_parsed=scope('a.z')), - identifiers=identifier) diff --git a/src/starkware/cairo/lang/compiler/import_loader.py b/src/starkware/cairo/lang/compiler/import_loader.py index b34e4608..124a58ff 100644 --- a/src/starkware/cairo/lang/compiler/import_loader.py +++ b/src/starkware/cairo/lang/compiler/import_loader.py @@ -3,7 +3,7 @@ from starkware.cairo.lang.compiler.ast.code_elements import ( CodeBlock, CodeElement, CodeElementFunction, CodeElementImport) from starkware.cairo.lang.compiler.ast.module import CairoFile -from starkware.cairo.lang.compiler.ast.visitor import Visitor +from starkware.cairo.lang.compiler.ast.visitor import Visitor, get_lang_from_file from starkware.cairo.lang.compiler.error_handling import Location, LocationError from starkware.cairo.lang.compiler.module_reader import ModuleNotFoundException from starkware.cairo.lang.compiler.parser import parse_file @@ -52,6 +52,7 @@ class ImportsCollector: def __init__(self, read_file: Callable[[str], Tuple[str, str]]): self.curr_ancestors: List[str] = [] self.collected_data: Dict[str, CairoFile] = {} + self.lang: Dict[str, Optional[str]] = {} self.read_file = read_file def collect(self, curr_pkg_name: str, location: Optional[Location] = None): @@ -74,6 +75,8 @@ def collect(self, curr_pkg_name: str, location: Optional[Location] = None): parsed_file: CairoFile = parse_file(code, filename=filename) + lang = get_lang_from_file(parsed_file) + # Get current file dependencies. collector = DirectDependenciesCollector() collector.get_using_pkgs_in_block(parsed_file.code_block) @@ -84,10 +87,16 @@ def collect(self, curr_pkg_name: str, location: Optional[Location] = None): # Collect ASTs recursively. 
for pkg_name, location in collector.packages: self.collect(pkg_name, location=location) + if not (self.lang[pkg_name] is None or self.lang[pkg_name] == lang): + raise ImportLoaderError( + f"Importing modules with %lang directive '{self.lang[pkg_name]}' must " + 'be from a module with the same directive.', + location=location) # Pop current package from ancestors list after scanning its dependencies. self.curr_ancestors.pop() self.collected_data[curr_pkg_name] = parsed_file + self.lang[curr_pkg_name] = lang class DirectDependenciesCollector(Visitor): diff --git a/src/starkware/cairo/lang/compiler/import_loader_test.py b/src/starkware/cairo/lang/compiler/import_loader_test.py index 0f8b3034..1febc599 100644 --- a/src/starkware/cairo/lang/compiler/import_loader_test.py +++ b/src/starkware/cairo/lang/compiler/import_loader_test.py @@ -1,9 +1,10 @@ +import re from random import sample from typing import Dict import pytest -from starkware.cairo.lang.compiler.error_handling import get_location_marks +from starkware.cairo.lang.compiler.error_handling import LocationError, get_location_marks from starkware.cairo.lang.compiler.import_loader import ( DirectDependenciesCollector, ImportLoaderError, UsingCycleError, collect_imports) from starkware.cairo.lang.compiler.parser import ParserError, parse_file @@ -178,3 +179,61 @@ def test_circular_dep(): a7 imports a8 imports a0""" + + +def test_lang_directive(): + files = { + 'a': """ +from c import x +""", + 'b': """ +%lang other_lang +from c import x +""", + 'c': """ +%lang lang +from d_lang import x +from d_no_lang import x +""", + 'd_lang': """ +%lang lang +const x = 0 +""", + 'd_no_lang': """ +const x = 0 +""", + 'e': """ +%lang lang # First line. +%lang lang # Second line. +"""} + + # Make sure that starting from 'c' does not raise an exception. 
+ collect_imports('c', read_file_from_dict(files)) + + verify_exception(files, 'a', """ +a:?:?: Importing modules with %lang directive 'lang' must be from a module with the same directive. +from c import x + ^ +""") + + verify_exception(files, 'b', """ +b:?:?: Importing modules with %lang directive 'lang' must be from a module with the same directive. +from c import x + ^ +""") + + verify_exception(files, 'e', """ +e:?:?: Found two %lang directives +%lang lang # Second line. +^********^ +""") + + +def verify_exception(files: Dict[str, str], main_file: str, error: str): + """ + Verifies that parsing the code results in the given error. + """ + with pytest.raises(LocationError) as e: + collect_imports(main_file, read_file_from_dict(files)) + # Remove line and column information from the error using a regular expression. + assert re.sub(':[0-9]+:[0-9]+: ', ':?:?: ', str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/instruction_test.py b/src/starkware/cairo/lang/compiler/instruction_test.py index 5bfd2a54..4b136a35 100644 --- a/src/starkware/cairo/lang/compiler/instruction_test.py +++ b/src/starkware/cairo/lang/compiler/instruction_test.py @@ -1,4 +1,4 @@ -from random import randint +from random import randrange import pytest @@ -7,8 +7,8 @@ def test_decode(): - offsets = [randint(0, 2**OFFSET_BITS) for _ in range(3)] - flags = randint(0, 2**N_FLAGS) + offsets = [randrange(0, 2**OFFSET_BITS) for _ in range(3)] + flags = randrange(0, 2**N_FLAGS) instruction = 0 for part in [flags] + offsets[::-1]: instruction = (instruction << OFFSET_BITS) | part diff --git a/src/starkware/cairo/lang/compiler/offset_reference.py b/src/starkware/cairo/lang/compiler/offset_reference.py index 5202a6b4..1dd86802 100644 --- a/src/starkware/cairo/lang/compiler/offset_reference.py +++ b/src/starkware/cairo/lang/compiler/offset_reference.py @@ -3,17 +3,12 @@ import marshmallow -from starkware.cairo.lang.compiler.ast.cairo_types import TypePointer, TypeStruct -from 
starkware.cairo.lang.compiler.ast.expr import ( - ExprAddressOf, ExprCast, ExprConst, ExprDeref, Expression, ExprOperator) +from starkware.cairo.lang.compiler.ast.expr import ExprDot, Expression, ExprIdentifier from starkware.cairo.lang.compiler.identifier_definition import ( - DefinitionError, IdentifierDefinition, ReferenceDefinition, StructDefinition) -from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierManager, MissingIdentifierError) + IdentifierDefinition, ReferenceDefinition) from starkware.cairo.lang.compiler.preprocessor.flow import ( FlowTrackingData, FlowTrackingDataActual, ReferenceManager) from starkware.cairo.lang.compiler.scoped_name import ScopedName -from starkware.cairo.lang.compiler.type_system_visitor import simplify_type_system @dataclasses.dataclass @@ -30,7 +25,6 @@ class OffsetReferenceDefinition(IdentifierDefinition): Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema parent: ReferenceDefinition - identifiers: IdentifierManager member_path: ScopedName def eval( @@ -41,38 +35,9 @@ def eval( name=self.parent.full_name) assert isinstance(flow_tracking_data, FlowTrackingDataActual), \ 'Resolved references can only come from FlowTrackingDataActual.' - expr, expr_type = simplify_type_system(reference.eval(flow_tracking_data.ap_tracking)) - for member_name in self.member_path.path: - if isinstance(expr_type, TypeStruct): - expr_type = expr_type.get_pointer_type() - # In this case, take the address of the reference value. 
- to_addr = lambda expr: ExprAddressOf(expr=expr) - else: - to_addr = lambda expr: expr - - if not isinstance(expr_type, TypePointer) or \ - not isinstance(expr_type.pointee, TypeStruct): - raise DefinitionError('Member access requires a type of the form Struct*.') - - struct_name = expr_type.pointee.resolved_scope - struct_def = self.identifiers.get_by_full_name(name=struct_name) - if struct_def is None: - raise MissingIdentifierError(struct_name) + expr = reference.eval(flow_tracking_data.ap_tracking) - if not isinstance(struct_def, StructDefinition): - raise DefinitionError(f"""\ -Expected '{struct_name}' to be a {StructDefinition.TYPE}. Found: '{struct_def.TYPE}'.""") - - member_definition = struct_def.members.get(member_name) - if member_definition is None: - raise DefinitionError( - f"'{member_name}' is not a member of '{struct_def.full_name}'.") - offset_value = member_definition.offset - expr_type = member_definition.cairo_type - - expr = ExprDeref(addr=ExprOperator(a=to_addr(expr), op='+', b=ExprConst(offset_value))) + for member_name in self.member_path.path: + expr = ExprDot(expr=expr, member=ExprIdentifier(name=member_name)) - return ExprCast( - expr=expr, - dest_type=expr_type, - ) + return expr diff --git a/src/starkware/cairo/lang/compiler/offset_reference_test.py b/src/starkware/cairo/lang/compiler/offset_reference_test.py index a74f65ce..d863e5d2 100644 --- a/src/starkware/cairo/lang/compiler/offset_reference_test.py +++ b/src/starkware/cairo/lang/compiler/offset_reference_test.py @@ -1,16 +1,12 @@ -import pytest - -from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer, TypeStruct -from starkware.cairo.lang.compiler.identifier_definition import ( - DefinitionError, MemberDefinition, ReferenceDefinition, StructDefinition) -from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.ast.cairo_types import TypePointer, TypeStruct +from 
starkware.cairo.lang.compiler.identifier_definition import ReferenceDefinition from starkware.cairo.lang.compiler.offset_reference import OffsetReferenceDefinition from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.preprocessor.flow import ( FlowTrackingDataActual, ReferenceManager, RegTrackingData) from starkware.cairo.lang.compiler.references import Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName -from starkware.cairo.lang.compiler.type_system_visitor import mark_types_in_expr_resolved +from starkware.cairo.lang.compiler.type_system import mark_types_in_expr_resolved scope = ScopedName.from_string @@ -18,25 +14,8 @@ def test_offset_reference_definition_typed_members(): t = TypeStruct(scope=scope('T'), is_fully_resolved=True) t_star = TypePointer(pointee=t) - s_star = TypePointer(pointee=TypeStruct(scope=scope('S'), is_fully_resolved=True)) reference_manager = ReferenceManager() - identifiers = IdentifierManager.from_dict({ - scope('T'): StructDefinition( - full_name='T', - members={ - 'x': MemberDefinition(offset=3, cairo_type=s_star), - 'flt': MemberDefinition(offset=4, cairo_type=TypeFelt()), - }, - size=5, - ), - scope('S'): StructDefinition( - full_name='S', - members={ - 'x': MemberDefinition(offset=10, cairo_type=t), - }, - size=15, - ), - }) + main_reference = ReferenceDefinition(full_name=scope('a'), cairo_type=t_star, references=[]) references = { scope('a'): reference_manager.alloc_id(Reference( @@ -51,25 +30,9 @@ def test_offset_reference_definition_typed_members(): reference_ids=references, ) - # Create OffsetReferenceDefinition instances for expressions of the form "a.", - # such as a.x and a.x.x, and check the result of evaluation those expressions. 
- for member_path, expected_result in [ - ('x', 'cast([ap - 1 + 3], S*)'), - ('x.x', 'cast([[ap - 1 + 3] + 10], T)'), - ('x.x.x', 'cast([&[[ap - 1 + 3] + 10] + 3], S*)'), - ('x.x.flt', 'cast([&[[ap - 1 + 3] + 10] + 4], felt)')]: - definition = OffsetReferenceDefinition( - parent=main_reference, identifiers=identifiers, member_path=scope(member_path)) - definition.eval( - reference_manager=reference_manager, - flow_tracking_data=flow_tracking_data).format() == expected_result - - definition = OffsetReferenceDefinition( - parent=main_reference, identifiers=identifiers, member_path=scope('x.x.flt.x')) - with pytest.raises(DefinitionError, match='Member access requires a type of the form Struct*.'): - definition.eval(reference_manager=reference_manager, flow_tracking_data=flow_tracking_data) - - definition = OffsetReferenceDefinition( - parent=main_reference, identifiers=identifiers, member_path=scope('x.y')) - with pytest.raises(DefinitionError, match="'y' is not a member of 'S'."): - definition.eval(reference_manager=reference_manager, flow_tracking_data=flow_tracking_data) + # Create OffsetReferenceDefinition instance for an expression of the form "a.", + # in this case a.x.y.z, and check the result of evaluation of this expression. 
+ definition = OffsetReferenceDefinition(parent=main_reference, member_path=scope('x.y.z')) + assert definition.eval( + reference_manager=reference_manager, + flow_tracking_data=flow_tracking_data).format() == 'cast(ap - 1, T*).x.y.z' diff --git a/src/starkware/cairo/lang/compiler/parser.py b/src/starkware/cairo/lang/compiler/parser.py index 041301c7..aa9ae8b0 100644 --- a/src/starkware/cairo/lang/compiler/parser.py +++ b/src/starkware/cairo/lang/compiler/parser.py @@ -51,6 +51,7 @@ def wrap_lark_error(err: LarkError, input_file: InputFile) -> Exception: expected.remove('_NEWLINE') TOKENS = { '_ARROW': '"->"', + '_AT': '"@"', '_DBL_EQ': '"=="', '_DBL_PLUS': '"++"', '_NEQ': '"!="', @@ -60,6 +61,7 @@ def wrap_lark_error(err: LarkError, input_file: InputFile) -> Exception: 'COLON': '":"', 'DOT': '"."', 'EQUAL': '"="', + 'FUNC': 'func', 'IDENTIFIER': 'identifier', 'INT': 'integer', 'LBRACE': '"{"', diff --git a/src/starkware/cairo/lang/compiler/parser_errors_test.py b/src/starkware/cairo/lang/compiler/parser_errors_test.py index 5ffbc102..d716f774 100644 --- a/src/starkware/cairo/lang/compiler/parser_errors_test.py +++ b/src/starkware/cairo/lang/compiler/parser_errors_test.py @@ -1,8 +1,7 @@ -import re - import pytest from starkware.cairo.lang.compiler.parser import ParserError, parse_file +from starkware.cairo.lang.compiler.parser_test_utils import verify_exception def test_unexpected_token(): @@ -23,15 +22,16 @@ def test_unexpected_token(): verify_exception(""" foo bar """, """ -file:?:?: Unexpected token Token(IDENTIFIER, 'bar'). Expected one of: "(", ".", ":", "=", "{", \ -operator. +file:?:?: Unexpected token Token(IDENTIFIER, 'bar'). Expected one of: "(", ".", ":", "=", "[", \ +"{", operator. foo bar ^*^ """) verify_exception(""" foo = bar test """, """ -file:?:?: Unexpected token Token(IDENTIFIER, 'test'). Expected one of: ".", ";", operator. +file:?:?: Unexpected token Token(IDENTIFIER, 'test'). Expected one of: "(", ".", ";", "[", "{", \ +operator. 
foo = bar test ^**^ """) @@ -45,21 +45,22 @@ def test_unexpected_token(): verify_exception(""" %[ 5 %] %[ 7 %] """, """ -file:?:?: Unexpected token Token(PYCONST, '%[ 7 %]'). Expected one of: "=", operator. +file:?:?: Unexpected token Token(PYCONST, '%[ 7 %]'). Expected one of: ".", "=", "[", operator. %[ 5 %] %[ 7 %] ^*****^ """) verify_exception(""" static_assert ap """, r""" -file:?:?: Unexpected token Token(_NEWLINE, '\n'). Expected one of: "==", operator. +file:?:?: Unexpected token Token(_NEWLINE, '\n'). Expected one of: ".", "==", "[", operator. static_assert ap ^ """) verify_exception(""" [ap] = x& + y """, """ -file:?:?: Unexpected token Token(AMPERSAND, '&'). Expected one of: ".", ";", operator. +file:?:?: Unexpected token Token(AMPERSAND, '&'). Expected one of: "(", ".", ";", "[", "{", \ +operator. [ap] = x& + y ^ """) @@ -87,7 +88,8 @@ def test_unexpected_token(): verify_exception(""" if x y """, """ -file:?:?: Unexpected token Token(IDENTIFIER, 'y'). Expected one of: "!=", ".", "==", operator. +file:?:?: Unexpected token Token(IDENTIFIER, 'y'). Expected one of: "!=", "(", ".", "==", "[", \ +"{", operator. if x y ^ """) @@ -109,10 +111,10 @@ def test_unexpected_token(): def test_unexpected_character(): verify_exception(""" -x@y +x~y """, """ -file:?:?: Unexpected character "@". -x@y +file:?:?: Unexpected character "~". +x~y ^ """) @@ -127,13 +129,3 @@ def test_parser_error(): assert str(e.value).endswith(""" const a = 5 ^""") - - -def verify_exception(code: str, error: str): - """ - Verifies that parsing the code results in the given error. - """ - with pytest.raises(ParserError) as e: - parse_file(code, '') - # Remove line and column information from the error using a regular expression. 
- assert re.sub(':[0-9]+:[0-9]+: ', 'file:?:?: ', str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/parser_test.py b/src/starkware/cairo/lang/compiler/parser_test.py index 0007933e..df5ebb13 100644 --- a/src/starkware/cairo/lang/compiler/parser_test.py +++ b/src/starkware/cairo/lang/compiler/parser_test.py @@ -2,9 +2,11 @@ from starkware.cairo.lang.compiler.ast.aliased_identifier import AliasedIdentifier from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypeTuple -from starkware.cairo.lang.compiler.ast.code_elements import CodeElementImport +from starkware.cairo.lang.compiler.ast.code_elements import ( + CodeElementImport, CodeElementReference, CodeElementReturnValueReference) from starkware.cairo.lang.compiler.ast.expr import ( - ExprConst, ExprDeref, ExprIdentifier, ExprOperator, ExprPyConst, ExprReg) + ExprConst, ExprDeref, ExprDot, ExprIdentifier, ExprOperator, ExprParentheses, ExprPyConst, + ExprReg, ExprSubscript) from starkware.cairo.lang.compiler.ast.formatting_utils import FormattingError from starkware.cairo.lang.compiler.ast.instructions import ( AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, @@ -15,7 +17,9 @@ from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.parser import ( parse, parse_code_element, parse_expr, parse_instruction, parse_type) +from starkware.cairo.lang.compiler.parser_test_utils import verify_exception from starkware.cairo.lang.compiler.parser_transformer import ParserError +from starkware.python.utils import safe_zip def test_types(): @@ -30,12 +34,30 @@ def test_type_tuple(): assert parse_type('( felt, felt* , (felt, T.S,)* )').format() == '(felt, felt*, (felt, T.S)*)' -def test_identifier(): +def test_identifier_and_dot(): assert parse_expr('x.y . z + x ').format() == 'x.y.z + x' + assert parse_expr(' [x]. y . 
z').format() == '[x].y.z' + assert parse_expr('(x-y).z').format() == '(x - y).z' + assert parse_expr('x-y.z').format() == 'x - y.z' + assert parse_expr('[ap+1].x.y').format() == '[ap + 1].x.y' + assert parse_expr('((a.b + c).d * e.f + g.h).i.j').format() == '((a.b + c).d * e.f + g.h).i.j' + + assert parse_expr('(x).y.z') == \ + ExprDot( + expr=ExprDot( + expr=ExprParentheses(val=ExprIdentifier(name='x')), + member=ExprIdentifier(name='y')), + member=ExprIdentifier(name='z')) + assert parse_expr('x.y.z') == ExprIdentifier(name='x.y.z') + with pytest.raises(ParserError): parse_expr('.x') with pytest.raises(ParserError): parse_expr('x.') + with pytest.raises(ParserError): + parse_expr('x.(y+z)') + with pytest.raises(ParserError): + parse_expr('x.[a]') def test_typed_identifier(): @@ -86,6 +108,29 @@ def test_deref_expr(): assert expr.format() == '[[fp - 7] + 3]' +def test_subscript_expr(): + assert parse_expr('x[y]').format() == 'x[y]' + assert parse_expr('[x][y][z][w]').format() == '[x][y][z][w]' + assert parse_expr(' x [ [ y[z[w]] ] ]').format() == 'x[[y[z[w]]]]' + assert parse_expr(' (x+y)[z+w] ').format() == '(x + y)[z + w]' + assert parse_expr('(&x)[3][(a-b)*2][&c]').format() == '(&x)[3][(a - b) * 2][&c]' + assert parse_expr('x[i+n*j]').format() == 'x[i + n * j]' + assert parse_expr('x+[y][z]').format() == 'x + [y][z]' + + assert parse_expr('[x][y][[z]]') == \ + ExprSubscript( + expr=ExprSubscript( + expr=ExprDeref(addr=ExprIdentifier(name='x')), + offset=ExprIdentifier(name='y') + ), + offset=ExprDeref(addr=ExprIdentifier(name='z'))) + + with pytest.raises(ParserError): + parse_expr('x[)]') + with pytest.raises(ParserError): + parse_expr('x[]') + + def test_operator_precedence(): code = '(5 + 2) - (3 - 9) * (7 + (-8)) - 10 * (-2) * 5 + (((7)))' expr = parse_expr(code) @@ -496,6 +541,48 @@ def test_format(args_str_wrong, args_str_right=''): test_format('(x #comment\n,y,z)->()') +def test_decoractor(): + code = """\ +@hello @world + + +@external func myfunc(): + 
return () +end""" + + assert parse_code_element(code=code).format(allowed_line_length=100) == """\ +@hello +@world +@external +func myfunc(): + return () +end""" + + +def test_decoractor_errors(): + verify_exception(""" +@hello world +func myfunc(): + return() +end +""", """ +file:?:?: Unexpected token Token(IDENTIFIER, \'world\'). Expected one of: "@", func. +@hello world + ^***^ +""") + + verify_exception(""" +@hello-world +func myfunc(): + return() +end +""", """ +file:?:?: Unexpected token Token(MINUS, \'-\'). Expected one of: "@", func. +@hello-world + ^ +""") + + def test_reference_type_annotation(): res = parse_code_element('let s : T * = ap') assert res.format(allowed_line_length=100) == 'let s : T* = ap' @@ -509,6 +596,16 @@ def test_addressof(): assert res.format(allowed_line_length=100) == 'static_assert &s.SIZE == ap' +def test_func_expr(): + res = parse_code_element('let x = f()') + assert isinstance(res, CodeElementReturnValueReference) + assert res.format(allowed_line_length=100) == 'let x = f()' + + res = parse_code_element('let x = (f())') + assert isinstance(res, CodeElementReference) + assert res.format(allowed_line_length=100) == 'let x = (f())' + + def test_locations(): code_with_marks = """\ [ap ] = [ fp + 2]; ap ++ @@ -535,6 +632,5 @@ def test_locations(): expr.body.b.addr.a, expr.body.b.addr.b, ] - assert len(exprs) == len(marks) - for expr, mark in zip(exprs, marks): + for expr, mark in safe_zip(exprs, marks): assert get_location_marks(code, expr.location) == code + '\n' + mark diff --git a/src/starkware/cairo/lang/compiler/parser_test_utils.py b/src/starkware/cairo/lang/compiler/parser_test_utils.py new file mode 100644 index 00000000..ec775d5c --- /dev/null +++ b/src/starkware/cairo/lang/compiler/parser_test_utils.py @@ -0,0 +1,15 @@ +import re + +import pytest + +from starkware.cairo.lang.compiler.parser import ParserError, parse_file + + +def verify_exception(code: str, error: str): + """ + Verifies that parsing the code results in the 
given error. + """ + with pytest.raises(ParserError) as e: + parse_file(code, '') + # Remove line and column information from the error using a regular expression. + assert re.sub(':[0-9]+:[0-9]+: ', 'file:?:?: ', str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/parser_transformer.py b/src/starkware/cairo/lang/compiler/parser_transformer.py index 557fbff3..c7f80686 100644 --- a/src/starkware/cairo/lang/compiler/parser_transformer.py +++ b/src/starkware/cairo/lang/compiler/parser_transformer.py @@ -15,10 +15,11 @@ CodeElementLabel, CodeElementLocalVariable, CodeElementMember, CodeElementReference, CodeElementReturn, CodeElementReturnValueReference, CodeElementStaticAssert, CodeElementTailCall, CodeElementTemporaryVariable, CodeElementUnpackBinding, CodeElementWith, - CommentedCodeElement) + CommentedCodeElement, LangDirective) from starkware.cairo.lang.compiler.ast.expr import ( - ArgList, ExprAddressOf, ExprAssignment, ExprCast, ExprConst, ExprDeref, ExprIdentifier, ExprNeg, - ExprOperator, ExprParentheses, ExprPyConst, ExprReg, ExprTuple) + ArgList, ExprAddressOf, ExprAssignment, ExprCast, ExprConst, ExprDeref, ExprDot, ExprIdentifier, + ExprNeg, ExprOperator, ExprParentheses, ExprPyConst, ExprReg, ExprSubscript, ExprTuple) +from starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall from starkware.cairo.lang.compiler.ast.instructions import ( AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, JnzInstruction, JumpInstruction, JumpToLabelInstruction, RetInstruction) @@ -118,6 +119,10 @@ def atom_pyconst(self, value, meta): def atom_reg(self, value, meta): return ExprReg(reg=value[0], location=self.meta2loc(meta)) + @v_args(meta=True) + def atom_func_call(self, value, meta): + return ExprFuncCall(rvalue=value[0], location=self.meta2loc(meta)) + @v_args(meta=True) def expr_add(self, value, meta): return ExprOperator( @@ -154,6 +159,15 @@ def atom_parentheses(self, value, meta): def 
atom_deref(self, value, meta): return ExprDeref(addr=value[1], notes=value[0], location=self.meta2loc(meta)) + @v_args(meta=True) + def atom_subscript(self, value, meta): + return ExprSubscript( + expr=value[0], offset=value[2], notes=value[1], location=self.meta2loc(meta)) + + @v_args(meta=True) + def atom_dot(self, value, meta): + return ExprDot(expr=value[0], member=value[1], location=self.meta2loc(meta)) + @v_args(meta=True) def atom_cast(self, value, meta): return ExprCast( @@ -398,17 +412,19 @@ def implicit_arguments(self, value): else: raise NotImplementedError(f'Unexpected argument: value={value}') - def code_element_function(self, value): - identifier, implicit_arguments, arguments = value[:3] + def decorator_list(self, value): + return value - if len(value) == 5: + def code_element_function(self, value): + decorators, identifier, implicit_arguments, arguments = value[:4] + if len(value) == 6: # Return values present. - returns = value[3] - code_block = value[4] - elif len(value) == 4: + returns = value[4] + code_block = value[5] + elif len(value) == 5: # Return values not present. 
returns = None - code_block = value[3] + code_block = value[4] else: raise NotImplementedError(f'Unexpected argument: value={value}') @@ -419,6 +435,7 @@ def code_element_function(self, value): implicit_arguments=implicit_arguments, returns=returns, code_block=code_block, + decorators=decorators, ) def code_element_struct(self, value): @@ -430,6 +447,7 @@ def code_element_struct(self, value): implicit_arguments=None, returns=None, code_block=code_block, + decorators=[], ) def code_element_with(self, value): @@ -474,6 +492,10 @@ def directive_builtins(self, value, meta): builtins = [ident.name for ident in value] return BuiltinsDirective(builtins=builtins, location=self.meta2loc(meta)) + @v_args(meta=True) + def directive_lang(self, value, meta): + return LangDirective(name=value[0].name, location=self.meta2loc(meta)) + @v_args(meta=True) def aliased_identifier(self, value, meta): if len(value) == 1: diff --git a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py index 739fc21a..d7999e77 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py @@ -1,5 +1,4 @@ import itertools -import re from typing import Optional import pytest @@ -11,7 +10,7 @@ CompoundExpressionContext, CompoundExpressionVisitor, SimplicityLevel, process_compound_expressions) from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( - PRIME, preprocess_str, verify_exception) + PRIME, preprocess_str, strip_comments_and_linebreaks, verify_exception) class CompoundExpressionTestContext(CompoundExpressionContext): @@ -240,7 +239,7 @@ def test_compound_expressions_long(): [ap] = [ap + (-2)] * [ap + (-1)]; ap++ # Compute x * x. [ap + (-10)] + [ap + (-4)] = [ap + (-1)] # Assert x + y * z + x / (-x - (y - z)) = x * x. 
""" - assert program.format() == re.sub(r'\s*#.*\n', '\n', expected_result) + assert program.format() == strip_comments_and_linebreaks(expected_result) def test_compound_expressions_tempvars(): @@ -280,7 +279,7 @@ def test_compound_expressions_localvar(): ret """ - assert program.format() == re.sub(r'\s*#.*\n', '\n', expected_result).replace('\n\n', '\n') + assert program.format() == strip_comments_and_linebreaks(expected_result) def test_compound_expressions_args(): @@ -311,7 +310,7 @@ def test_compound_expressions_args(): [ap] = [ap + (-5)] + [ap + (-4)]; ap++ # Push 3 * x + x * x. call rel -15 """ - assert program.format() == re.sub(r'\s*#.*\n', '\n', expected_result).replace('\n\n', '\n') + assert program.format() == strip_comments_and_linebreaks(expected_result) def test_compound_expressions_failures(): @@ -347,16 +346,6 @@ def test_compound_expressions_failures(): struct T: member a : felt end -assert cast([ap], T) = cast([ap], T) -""", """ -file:?:?: Expected a 'felt' or a pointer type. Got: 'test_scope.T'. -assert cast([ap], T) = cast([ap], T) - ^***********^ -""") - verify_exception("""\ -struct T: - member a : felt -end assert 7 = cast(7, T*) """, """ file:?:?: Cannot compare 'felt' and 'test_scope.T*'. diff --git a/src/starkware/cairo/lang/compiler/preprocessor/conftest.py b/src/starkware/cairo/lang/compiler/preprocessor/conftest.py index 98c3adc3..98e6bbb1 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/conftest.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/conftest.py @@ -1,3 +1,6 @@ import pytest +# Instruct pytest to print full information (e.g., the values on both sides of the equality) +# about asserts that failed in the module below. +# Normally, pytest prints full information only for test files (according to their name). 
pytest.register_assert_rewrite('starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils') diff --git a/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py b/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py new file mode 100644 index 00000000..468aea5f --- /dev/null +++ b/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py @@ -0,0 +1,96 @@ +from typing import Callable, Dict, Optional, Sequence, Tuple, Type + +from starkware.cairo.lang.compiler.ast.module import CairoModule +from starkware.cairo.lang.compiler.import_loader import collect_imports +from starkware.cairo.lang.compiler.preprocessor.dependency_graph import DependencyGraphStage +from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector +from starkware.cairo.lang.compiler.preprocessor.pass_manager import ( + PassManager, PassManagerContext, Stage, VisitorStage) +from starkware.cairo.lang.compiler.preprocessor.preprocessor import Preprocessor +from starkware.cairo.lang.compiler.preprocessor.struct_collector import StructCollector +from starkware.cairo.lang.compiler.preprocessor.unique_labels import UniqueLabelCreator +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +def default_pass_manager( + prime: int, + read_module: Callable[[str], Tuple[str, str]], + preprocessor_cls: Optional[Type[Preprocessor]] = None, + opt_unused_functions: bool = True, + preprocessor_kwargs: Optional[Dict] = None) -> PassManager: + manager = PassManager() + manager.add_stage('module_collector', ModuleCollector(read_module=read_module)) + manager.add_stage('unique_label_creator', VisitorStage( + lambda context: UniqueLabelCreator(), modify_ast=True)) + manager.add_stage('identifier_collector', VisitorStage( + lambda context: IdentifierCollector(identifiers=context.identifiers))) + if opt_unused_functions: + manager.add_stage('dependency_graph', DependencyGraphStage()) + 
manager.add_stage('struct_collector', VisitorStage( + lambda context: StructCollector(identifiers=context.identifiers))) + manager.add_stage('preprocessor', PreprocessorStage( + prime, preprocessor_cls, preprocessor_kwargs)) + return manager + + +class PreprocessorStage(Stage): + def __init__( + self, prime: int, preprocessor_cls: Optional[Type[Preprocessor]] = None, + preprocessor_kwargs: Optional[Dict] = None): + self.prime = prime + if preprocessor_cls is None: + self.preprocessor_cls = Preprocessor + else: + self.preprocessor_cls = preprocessor_cls + self.preprocessor_kwargs = {} if preprocessor_kwargs is None else preprocessor_kwargs + + def run(self, context: PassManagerContext): + preprocessor = self.preprocessor_cls( + prime=self.prime, identifiers=context.identifiers, + functions_to_compile=context.functions_to_compile, **self.preprocessor_kwargs) + preprocessor.identifier_locations = context.identifier_locations + + for module in context.modules: + preprocessor.visit(module) + + preprocessor.resolve_labels() + context.preprocessed_program = preprocessor.get_program() + + +class ModuleCollector(Stage): + def __init__( + self, read_module: Callable[[str], Tuple[str, str]], + additional_modules: Optional[Sequence[str]] = None): + self.read_module = read_module + self.additional_modules = [] if additional_modules is None else list(additional_modules) + + def run(self, context: PassManagerContext): + visited_modules = set() + + for additional_module in self.additional_modules: + files = collect_imports(additional_module, read_file=self.read_module) + for module_name, ast in files.items(): + if module_name in visited_modules: + continue + visited_modules.add(module_name) + scope = ScopedName.from_string(module_name) + context.modules.append(CairoModule(cairo_file=ast, module_name=scope)) + + for code, filename in context.codes: + # Function used to read files given module names. 
+ # The root module (filename) is handled separately, for this module code is returned. + def read_file_fixed(name): + return (code, filename) if name == filename else self.read_module(name) + + files = collect_imports(filename, read_file=read_file_fixed) + for module_name, ast in files.items(): + # Check if the module is one of the files given in 'context.codes'. + is_main_scope = module_name == filename + if is_main_scope: + scope = context.main_scope + else: + scope = ScopedName.from_string(module_name) + if module_name in visited_modules: + continue + visited_modules.add(module_name) + context.modules.append(CairoModule(cairo_file=ast, module_name=scope)) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py index 75b5866c..be5e952a 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py @@ -1,9 +1,15 @@ -from typing import Dict, List +from typing import Dict, List, Optional, Set -from starkware.cairo.lang.compiler.ast.code_elements import CodeElementImport -from starkware.cairo.lang.compiler.ast.expr import ExprIdentifier +from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction, CodeElementImport +from starkware.cairo.lang.compiler.ast.expr import ExprAssignment, ExprDot, ExprIdentifier +from starkware.cairo.lang.compiler.ast.module import CairoModule from starkware.cairo.lang.compiler.ast.visitor import Visitor -from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.identifier_definition import AliasDefinition +from starkware.cairo.lang.compiler.identifier_manager import ( + IdentifierManager, MissingIdentifierError) +from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManagerContext, Stage +from 
starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -18,28 +24,138 @@ def __init__(self, identifiers: IdentifierManager): self.identifiers = identifiers # A dictionary from a scope name to the list of identifiers it uses. self.visited_identifiers: Dict[ScopedName, List[ScopedName]] = {} + # Tracks the current function being visited. + self.current_function: Optional[ScopedName] = None def _visit_default(self, obj): for child in obj.get_children(): if child is not None: self.visit(child) - def add_identifier(self, name: ScopedName, is_resolved: bool = False): + def add_identifier( + self, name: ScopedName, location: Optional[Location], is_resolved: bool = False): + if name.path[-1] == '_': + return if is_resolved: canonical_name = name else: - canonical_name = self.identifiers.search( - accessible_scopes=self.accessible_scopes, name=name).canonical_name + try: + canonical_name = self.identifiers.search( + accessible_scopes=self.accessible_scopes, name=name).canonical_name + except MissingIdentifierError as e: + raise PreprocessorError(str(e), location=location) - self.visited_identifiers.setdefault(self.current_scope, []).append( - canonical_name) + if self.current_function is not None: + self.visited_identifiers.setdefault(self.current_function, []).append( + canonical_name) + + def visit_CodeElementMember(self, elm): + pass + + def visit_ExprDot(self, expr: ExprDot): + # We override the default visitor, since we must not visit expr.member. + self.visit(expr.expr) + + def visit_CodeElementFunction(self, elm: CodeElementFunction): + if elm.element_type == 'func': + # Update self.current_function. + old_current_function = self.current_function + try: + self.current_function = self.current_scope + elm.name + # Enforce that every function appears in visited_identifiers. 
+ self.visited_identifiers.setdefault(self.current_scope + elm.name, []) + super().visit_CodeElementFunction(elm) + finally: + self.current_function = old_current_function + else: + super().visit_CodeElementFunction(elm) + + def visit_ExprAssignment(self, elm: ExprAssignment): + # We override the default visitor, since we must not visit expr.identifier. + self.visit(elm.expr) def visit_ExprIdentifier(self, expr: ExprIdentifier): - self.add_identifier(ScopedName.from_string(expr.name)) + self.add_identifier(ScopedName.from_string(expr.name), location=expr.location) def visit_CodeElementImport(self, code_elm: CodeElementImport): for import_item in code_elm.import_items: self.add_identifier( ScopedName.from_string(code_elm.path.name) + ScopedName.from_string(import_item.orig_identifier.name), - is_resolved=True) + is_resolved=True, + location=code_elm.location) + + def find_function_dependencies(self, functions: Set[ScopedName]) -> Set[ScopedName]: + """ + Finds all the transitive dependencies of a given set of functions. + """ + finder = FunctionDependencyFinder(self.visited_identifiers) + for x in functions: + if x not in self.visited_identifiers: + continue + finder.visit(x) + return finder.visited + + +class FunctionDependencyFinder: + """ + A class helper for find_function_dependencies. + """ + + def __init__(self, identifer_dependencies: Dict[ScopedName, List[ScopedName]]): + self.identifer_dependencies = identifer_dependencies + self.visited: Set[ScopedName] = set() + + def visit(self, name: ScopedName): + # Find the largest prefix that is a function. + while len(name.path) > 0 and name not in self.identifer_dependencies: + name = name[:-1] + if name not in self.identifer_dependencies: + # No such function. 
+ return + if name in self.visited: + return + self.visited.add(name) + for identifier in self.identifer_dependencies[name]: + self.visit(identifier) + + +def get_main_functions_to_compile( + identifiers: IdentifierManager, main_scope: ScopedName) -> Set[ScopedName]: + """ + Retrieves the root functions to compile from a main scope. + The definition of which functions we need to compile is somewhat arbitrary: + All functions explicitly defined, or aliased in the main scope. + """ + main_functions: Set[ScopedName] = set() + try: + scope = identifiers.get_scope(main_scope) + main_functions = {main_scope + name for name in scope.subscopes} + main_functions |= { + identifier_definition.destination + for identifier_definition in scope.identifiers.values() + if isinstance(identifier_definition, AliasDefinition)} + except MissingIdentifierError: + return set() + return main_functions + + +def get_functions_to_compile( + modules: List[CairoModule], identifiers: IdentifierManager, + main_scope: ScopedName) -> Set[ScopedName]: + """ + Returns a set of reachable function (starting from the functions in the main scope). 
+ """ + + dependency_graph = DependencyGraphVisitor(identifiers) + for module in modules: + dependency_graph.visit(module) + return dependency_graph.find_function_dependencies(get_main_functions_to_compile( + identifiers=identifiers, main_scope=main_scope)) + + +class DependencyGraphStage(Stage): + def run(self, context: PassManagerContext): + assert context.functions_to_compile is None + context.functions_to_compile = get_functions_to_compile( + modules=context.modules, identifiers=context.identifiers, main_scope=context.main_scope) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py index 16276853..8291b8e8 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py @@ -1,16 +1,19 @@ -from typing import Dict, Set +from typing import Dict from starkware.cairo.lang.compiler.ast.module import CairoModule from starkware.cairo.lang.compiler.parser import parse_file -from starkware.cairo.lang.compiler.preprocessor.dependency_graph import DependencyGraphVisitor +from starkware.cairo.lang.compiler.preprocessor.dependency_graph import ( + DependencyGraphVisitor, get_main_functions_to_compile) from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector from starkware.cairo.lang.compiler.scoped_name import ScopedName +scope = ScopedName.from_string -def _extract_dependency_graph(codes: Dict[str, str]) -> Dict[str, Set[str]]: + +def _extract_dependency_graph(codes: Dict[str, str]) -> DependencyGraphVisitor: """ Extracts the dependencies from the given codes (given as a map from a file name to its content). - Returns the dependencies as a map from scope name to a set of the identifiers it uses. + Returns the DependencyGraphVisitor instance. 
""" modules = [ CairoModule( @@ -23,15 +26,14 @@ def _extract_dependency_graph(codes: Dict[str, str]) -> Dict[str, Set[str]]: dependency_graph_visitor = DependencyGraphVisitor(identifiers=identifier_collector.identifiers) for module in modules: dependency_graph_visitor.visit(module) - return { - str(scope): set(map(str, deps)) - for scope, deps in dependency_graph_visitor.visited_identifiers.items()} + return dependency_graph_visitor def test_dependency_graph(): - modules = {'module': """ -func func0(): - return () + modules = { + 'module': """ +func func0() -> (res): + return (res=0) end func func1(): return () @@ -39,7 +41,11 @@ def test_dependency_graph(): func func2(): return () end -""", '__main__': """ +func func3(): + return () +end +""", + '__main__': """ from module import func1 as func1_alias func foo(): @@ -61,12 +67,16 @@ def test_dependency_graph(): let _reference = [fp] + 2 _label: - let _typed_reference : W = myfunc(1, 2, 3) + let _typed_reference : W = ns.myfunc(1, 2, 3) end -func myfunc(): - myfunc() - func1_alias() +namespace ns: + func myfunc(): + myfunc() + func1_alias() + end + + call bar # This line will be ignored since it's outside of any function. end struct W: @@ -85,12 +95,18 @@ def test_dependency_graph(): jmp foo._label call bar end + +call bar # This line will be ignored since it's outside of any function. 
+""", + '': """ +from module import func2 """} - assert _extract_dependency_graph(modules) == { - '__main__': { - 'module.func1', - }, + dependency_graph_visitor = _extract_dependency_graph(modules) + dependencies = { + str(scope): set(map(str, deps)) + for scope, deps in dependency_graph_visitor.visited_identifiers.items()} + assert dependencies == { '__main__.foo': { '__main__.foo._tempvar', '__main__.foo._const', @@ -98,12 +114,12 @@ def test_dependency_graph(): '__main__.foo._reference', '__main__.foo._label', '__main__.foo._typed_reference', - '__main__.myfunc', + '__main__.ns.myfunc', 'module.func0', 'module.func2', }, - '__main__.myfunc': { - '__main__.myfunc', + '__main__.ns.myfunc': { + '__main__.ns.myfunc', 'module.func1', }, '__main__.bar': { @@ -118,4 +134,62 @@ def test_dependency_graph(): '__main__.bar', '__main__.foo._label', }, + 'module.func0': set(), + 'module.func1': set(), + 'module.func2': set(), + 'module.func3': set(), + } + + assert dependency_graph_visitor.find_function_dependencies( + {scope('__main__.main')}) == { + ScopedName(path=('__main__', 'bar')), + ScopedName(path=('__main__', 'foo')), + ScopedName(path=('__main__', 'main')), + ScopedName(path=('__main__', 'ns', 'myfunc')), + ScopedName(path=('module', 'func0')), + ScopedName(path=('module', 'func1')), + ScopedName(path=('module', 'func2')), + } + assert dependency_graph_visitor.find_function_dependencies( + {scope('__main__.ns.myfunc')}) == { + ScopedName(path=('__main__', 'ns', 'myfunc')), + ScopedName(path=('module', 'func1')), + } + assert dependency_graph_visitor.find_function_dependencies( + {scope('__main__.ns.myfunc'), scope('__main__.bar')}) == { + ScopedName(path=('__main__', 'bar')), + ScopedName(path=('__main__', 'foo')), + ScopedName(path=('__main__', 'ns', 'myfunc')), + ScopedName(path=('module', 'func0')), + ScopedName(path=('module', 'func1')), + ScopedName(path=('module', 'func2')), + } + assert dependency_graph_visitor.find_function_dependencies( + {scope('foo')}) 
== set() + + # Test get_main_functions_to_compile(). + + assert get_main_functions_to_compile( + identifiers=dependency_graph_visitor.identifiers, + main_scope=scope('module')) == { + scope('module.func0'), + scope('module.func1'), + scope('module.func2'), + scope('module.func3'), + } + assert get_main_functions_to_compile( + identifiers=dependency_graph_visitor.identifiers, + main_scope=scope('__main__')) == { + scope('module.func1'), + scope('__main__.foo'), + scope('__main__.ns'), + scope('__main__.bar'), + scope('__main__.main'), + } + assert get_main_functions_to_compile( + identifiers=dependency_graph_visitor.identifiers, + main_scope=scope('')) == { + scope('module.func2'), + scope('module'), + scope('__main__'), } diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py index c8ac8dca..7d909dca 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py @@ -9,8 +9,7 @@ from starkware.cairo.lang.compiler.identifier_definition import ( DefinitionError, FutureIdentifierDefinition, IdentifierDefinition, StructDefinition) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager -from starkware.cairo.lang.compiler.identifier_utils import ( - get_struct_definition, resolve_search_result) +from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -33,7 +32,8 @@ def handle_missing_future_definition(self, name: ScopedName, location): location=location) def add_name_definition( - self, name: ScopedName, identifier_definition: IdentifierDefinition, location): + self, name: ScopedName, identifier_definition: IdentifierDefinition, location, 
+ require_future_definition=True): """ Adds a definition of an identifier named 'name' at 'location'. The identifier must already be found as a FutureIdentifierDefinition in 'self.identifiers' @@ -42,7 +42,8 @@ def add_name_definition( future_definition = self.identifiers.get_by_full_name(name) if future_definition is None: - self.handle_missing_future_definition(name=name, location=location) + if require_future_definition: + self.handle_missing_future_definition(name=name, location=location) else: if not isinstance(future_definition, FutureIdentifierDefinition): raise PreprocessorError(f"Redefinition of '{name}'.", location=location) @@ -66,10 +67,10 @@ def get_struct_definition( try: res = self.identifiers.search( accessible_scopes=self.accessible_scopes, name=name) + res.assert_fully_parsed() except IdentifierError as exc: raise PreprocessorError(str(exc), location=location) - res.assert_fully_parsed() struct_def = res.identifier_definition if not isinstance(struct_def, StructDefinition): raise PreprocessorError( @@ -79,17 +80,14 @@ def get_struct_definition( return struct_def - def search_identifier( - self, name: str, location: Optional[Location]) -> Optional[IdentifierDefinition]: + def try_get_struct_definition(self, name: ScopedName) -> Optional[StructDefinition]: """ - Searches for the given identifier in self.identifiers and returns the corresponding - IdentifierDefinition. + Same as get_struct_definition() except that None is returned in case of a failure. 
""" try: - result = self.identifiers.search(self.accessible_scopes, ScopedName.from_string(name)) - return resolve_search_result(result, identifiers=self.identifiers) - except IdentifierError as exc: - raise PreprocessorError(str(exc), location=location) + return self.get_struct_definition(name, None) + except PreprocessorError: + return None def get_canonical_struct_name(self, scoped_name: ScopedName, location: Optional[Location]): """ @@ -142,8 +140,9 @@ def resolve_type(self, cairo_type: CairoType) -> CairoType: except IdentifierError as exc: raise PreprocessorError(str(exc), location=cairo_type.location) elif isinstance(cairo_type, TypeTuple): - raise PreprocessorError( - 'Tuples are not supported yet.', location=cairo_type.location) + return dataclasses.replace( + cairo_type, + members=[self.resolve_type(subtype) for subtype in cairo_type.members]) else: raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') @@ -167,8 +166,7 @@ def get_size(self, cairo_type: CairoType): return self.get_struct_size( struct_name=cairo_type.scope, location=cairo_type.location) elif isinstance(cairo_type, TypeTuple): - raise PreprocessorError( - 'Tuples are not supported yet.', location=cairo_type.location) + return sum(self.get_size(member_type) for member_type in cairo_type.members) else: raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py new file mode 100644 index 00000000..2153a8d4 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py @@ -0,0 +1,23 @@ +import pytest + +from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition +from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( + IdentifierAwareVisitor) +from 
starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +def test_add_name_definition_no_future(): + visitor = IdentifierAwareVisitor() + + test_id = ScopedName.from_string('test_id') + location = None + + visitor.add_name_definition( + name=test_id, identifier_definition=ConstDefinition(value=1), location=location, + require_future_definition=False) + + with pytest.raises(PreprocessorError, match=f"Redefinition of 'test_id'."): + visitor.add_name_definition( + name=test_id, identifier_definition=ConstDefinition(value=1), location=location, + require_future_definition=False) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py index baa01873..9cdcdd77 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py @@ -9,9 +9,9 @@ from starkware.cairo.lang.compiler.ast.visitor import Visitor from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, ConstDefinition, FutureIdentifierDefinition, IdentifierDefinition, - LabelDefinition, ReferenceDefinition, StructDefinition) -from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager + AliasDefinition, ConstDefinition, FunctionDefinition, FutureIdentifierDefinition, + IdentifierDefinition, LabelDefinition, ReferenceDefinition, StructDefinition) +from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager from starkware.cairo.lang.compiler.preprocessor.local_variables import N_LOCALS_CONSTANT from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -45,9 +45,9 @@ class 
IdentifierCollector(Visitor): CodeElementReturnValueReference: ReferenceDefinition, } - def __init__(self): + def __init__(self, identifiers: Optional[IdentifierManager] = None): super().__init__() - self.identifiers = IdentifierManager() + self.identifiers = IdentifierManager() if identifiers is None else identifiers def add_identifier( self, name: ScopedName, identifier_definition: IdentifierDefinition, @@ -152,8 +152,9 @@ def handle_function_arguments( 'argument.', location=arg_id.location) + ident_type = FunctionDefinition if elm.element_type == 'func' else LabelDefinition self.add_future_identifier( - function_scope, LabelDefinition, elm.identifier.location) + function_scope, ident_type, elm.identifier.location) # Add SIZEOF_LOCALS for current block at identifier definition location if available. self.add_future_identifier( @@ -200,10 +201,13 @@ def visit_CodeElementImport(self, elm: CodeElementImport): # Ensure destination is a valid identifier. if self.identifiers.get_by_full_name(alias_dst) is None: - raise PreprocessorError( - f"Scope '{elm.path.name}' does not include identifier " - f"'{import_item.orig_identifier.name}'.", - location=import_item.orig_identifier.location) + try: + self.identifiers.get_scope(alias_dst) + except IdentifierError: + raise PreprocessorError( + f"Cannot import '{import_item.orig_identifier.name}' " + f"from '{elm.path.name}'.", + location=import_item.orig_identifier.location) # Add alias to identifiers. 
self.add_identifier( diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py index 47e8282e..75cf3d4e 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py @@ -1,5 +1,6 @@ from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, ConstDefinition, LabelDefinition, ReferenceDefinition, StructDefinition) + AliasDefinition, ConstDefinition, FunctionDefinition, LabelDefinition, ReferenceDefinition, + StructDefinition) from starkware.cairo.lang.compiler.parser import parse_file from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import verify_exception @@ -46,7 +47,7 @@ def test_collect_multi_binds(): let (e, f) = g() """ assert set(_extract_identifiers(code)) == { - ('a', LabelDefinition), + ('a', FunctionDefinition), ('a.SIZEOF_LOCALS', ConstDefinition), ('a.Args', StructDefinition), ('a.ImplicitArgs', StructDefinition), @@ -68,7 +69,7 @@ def test_nested_funcs(): end """ assert set(_extract_identifiers(code)) == { - ('foo', LabelDefinition), + ('foo', FunctionDefinition), ('foo.SIZEOF_LOCALS', ConstDefinition), ('foo.Args', StructDefinition), ('foo.ImplicitArgs', StructDefinition), @@ -76,7 +77,7 @@ def test_nested_funcs(): ('foo.x', ReferenceDefinition), ('foo.z', ReferenceDefinition), ('foo.a', ReferenceDefinition), - ('foo.bar', LabelDefinition), + ('foo.bar', FunctionDefinition), ('foo.bar.SIZEOF_LOCALS', ConstDefinition), ('foo.bar.Args', StructDefinition), ('foo.bar.ImplicitArgs', StructDefinition), diff --git a/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py b/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py new file mode 100644 index 00000000..066bd7b1 --- /dev/null 
+++ b/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py @@ -0,0 +1,104 @@ +import dataclasses +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Set, Tuple + +from starkware.cairo.lang.compiler.ast.module import CairoModule +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.preprocessor.preprocessor import PreprocessedProgram +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +@dataclasses.dataclass +class PassManagerContext: + # A list of pairs (code, filename). + codes: List[Tuple[str, str]] + main_scope: ScopedName + identifiers: IdentifierManager + modules: List[CairoModule] = dataclasses.field(default_factory=list) + identifier_locations: Dict[ScopedName, Location] = dataclasses.field(default_factory=dict) + preprocessed_program: Optional[PreprocessedProgram] = None + # A set of functions to compile (None means all functions will be compiled). + # If the unused function optimization is enabled, only reachable functions will be compiled. + functions_to_compile: Optional[Set[ScopedName]] = None + + +class Stage(ABC): + """ + Represents a compilation stage. + """ + + @abstractmethod + def run(self, context: PassManagerContext): + """ + Runs the stage on the given context. The stage may modify context. + """ + + +class PassManager: + """ + Manages the preprocessor's stages. + """ + + def __init__(self): + # The list of stages. + self.stages: List[Tuple[str, Stage]] = [] + # A set of stage names. 
+ self.stage_names: Set[str] = set() + + def run(self, context: PassManagerContext): + for _, stage in self.stages: + stage.run(context) + + def get_stage_index(self, name: str): + assert name in self.stage_names + index, = [i for i, (stage_name, _) in enumerate(self.stages) if stage_name == name] + return index + + # Functions for manipulating the stages: + + def add_stage(self, new_stage_name: str, new_stage: Stage, index: Optional[int] = None): + """ + Adds a stage at the given index (at the end if index is None). + """ + assert new_stage_name not in self.stage_names + if index is None: + index = len(self.stages) + self.stages.insert(index, (new_stage_name, new_stage)) + self.stage_names.add(new_stage_name) + + def add_before(self, existing_stage: str, new_stage_name: str, new_stage: Stage): + """ + Adds a new stage before 'existing_stage'. + """ + self.add_stage(new_stage_name, new_stage, index=self.get_stage_index(existing_stage)) + + def add_after(self, existing_stage: str, new_stage_name: str, new_stage: Stage): + """ + Adds a new stage after 'existing_stage'. + """ + self.add_stage(new_stage_name, new_stage, index=self.get_stage_index(existing_stage) + 1) + + def replace(self, existing_stage: str, new_stage: Stage): + """ + Replaces 'existing_stage' with the given stage. + """ + self.stages[self.get_stage_index(existing_stage)] = (existing_stage, new_stage) + + +class VisitorStage(Stage): + """ + A generic stage that runs a visitor on the AST.
+ """ + + def __init__(self, visitor_factory, modify_ast=False): + self.visitor_factory = visitor_factory + self.modify_ast = modify_ast + + def run(self, context: PassManagerContext): + visitor = self.visitor_factory(context) + modified_modules = [] + for module in context.modules: + modified_modules.append(visitor.visit(module)) + if self.modify_ast: + context.modules = modified_modules diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py new file mode 100644 index 00000000..79b29c8d --- /dev/null +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py @@ -0,0 +1,25 @@ +from typing import Sequence, Tuple + +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManager, PassManagerContext +from starkware.cairo.lang.compiler.preprocessor.preprocessor import PreprocessedProgram +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +def preprocess_codes( + codes: Sequence[Tuple[str, str]], pass_manager: PassManager, + main_scope: ScopedName = ScopedName()) -> PreprocessedProgram: + """ + Preprocesses a list of Cairo files and returns a PreprocessedProgram instance. + codes is a list of pairs (code_string, file_name). 
+ """ + context = PassManagerContext( + codes=list(codes), + main_scope=main_scope, + identifiers=IdentifierManager(), + ) + + pass_manager.run(context) + + assert context.preprocessed_program is not None + return context.preprocessed_program diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py index aee2b5c2..bad8f1c9 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py @@ -1,21 +1,23 @@ import dataclasses from contextlib import contextmanager from enum import Enum, auto -from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Type, cast +from typing import Dict, List, Optional, Set, Tuple, cast from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, CastType, TypeFelt, TypePointer, TypeStruct) + CairoType, CastType, TypeFelt, TypePointer, TypeStruct, TypeTuple) from starkware.cairo.lang.compiler.ast.code_elements import ( BuiltinsDirective, CodeBlock, CodeElement, CodeElementAllocLocals, CodeElementCompoundAssertEq, CodeElementConst, CodeElementDirective, CodeElementEmptyLine, CodeElementFuncCall, CodeElementFunction, CodeElementHint, CodeElementIf, CodeElementImport, CodeElementInstruction, CodeElementLabel, CodeElementLocalVariable, CodeElementMember, CodeElementReference, CodeElementReturn, CodeElementReturnValueReference, CodeElementStaticAssert, - CodeElementTailCall, CodeElementTemporaryVariable, CodeElementUnpackBinding, CodeElementWith) + CodeElementTailCall, CodeElementTemporaryVariable, CodeElementUnpackBinding, CodeElementWith, + LangDirective) from starkware.cairo.lang.compiler.ast.expr import ( ExprAssignment, ExprCast, ExprConst, ExprDeref, Expression, ExprFutureLabel, ExprIdentifier, - ExprOperator, ExprReg) + ExprOperator, ExprReg, ExprTuple) +from 
starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall from starkware.cairo.lang.compiler.ast.formatting_utils import get_max_line_length from starkware.cairo.lang.compiler.ast.instructions import ( AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, @@ -26,11 +28,10 @@ from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, DefinitionError, FutureIdentifierDefinition, LabelDefinition, MemberDefinition, - ReferenceDefinition) + ConstDefinition, DefinitionError, FunctionDefinition, FutureIdentifierDefinition, + IdentifierDefinition, LabelDefinition, MemberDefinition, ReferenceDefinition) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition -from starkware.cairo.lang.compiler.import_loader import collect_imports from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.instruction_builder import ( InstructionBuilderError, get_instruction_size) @@ -44,20 +45,19 @@ ReferenceManager) from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( IdentifierAwareVisitor) -from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector from starkware.cairo.lang.compiler.preprocessor.local_variables import ( create_simple_ref_expr, preprocess_local_variables) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.preprocessor.preprocessor_utils import assert_no_modifier from starkware.cairo.lang.compiler.preprocessor.reg_tracking import ( RegChange, RegChangeKnown, RegChangeUnconstrained, RegChangeUnknown, RegTrackingData) -from 
starkware.cairo.lang.compiler.preprocessor.struct_collector import StructCollector -from starkware.cairo.lang.compiler.preprocessor.unique_labels import UniqueLabelCreator from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference, translate_ap +from starkware.cairo.lang.compiler.resolve_search_result import resolve_search_result from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.substitute_identifiers import substitute_identifiers from starkware.cairo.lang.compiler.type_casts import check_cast from starkware.cairo.lang.compiler.type_system_visitor import get_expr_addr, simplify_type_system +from starkware.python.utils import safe_zip @dataclasses.dataclass @@ -135,10 +135,19 @@ class Preprocessor(IdentifierAwareVisitor): * Labels. * Constant values. * Functions. + + Arguments: + prime: The prime we are compiling for. + identifiers: An optional initial IdentifierManager. + supported_decorators: A set of decorators that may appear before a function declaration. + functions_to_compile: A set of functions to compile. None means compile everything. """ - def __init__(self, prime: int, identifiers: Optional[IdentifierManager] = None): - super().__init__(identifiers=identifiers) + def __init__( + self, prime: int, identifiers: Optional[IdentifierManager] = None, + supported_decorators: Optional[Set[str]] = None, + functions_to_compile: Optional[Set[ScopedName]] = None): + super().__init__(identifiers=identifiers,) self.prime: int = prime self.instructions: List[PreprocessedInstruction] = [] # Stores the program counter of the next instruction (where the first instruction is at 0).
@@ -170,6 +179,27 @@ def __init__(self, prime: int, identifiers: Optional[IdentifierManager] = None): # identifier collector self.scoped_temp_ids: Set[ScopedName] = set() + if supported_decorators is None: + supported_decorators = set() + self.supported_decorators = supported_decorators + + self.functions_to_compile = functions_to_compile + # A set of all scoped prefixes that were not traversed and need to be pruned from the + # identifier manager. + self.removed_prefixes: Set[ScopedName] = set() + + def search_identifier( + self, name: str, location: Optional[Location]) -> Optional[IdentifierDefinition]: + """ + Searches for the given identifier in self.identifiers and returns the corresponding + IdentifierDefinition. + """ + try: + result = self.identifiers.search(self.accessible_scopes, ScopedName.from_string(name)) + return resolve_search_result(result, identifiers=self.identifiers) + except IdentifierError as exc: + raise PreprocessorError(str(exc), location=location) + def handle_missing_future_definition(self, name: ScopedName, location): if name not in self.scoped_temp_ids: super().handle_missing_future_definition(name=name, location=location) @@ -245,6 +275,8 @@ def resolve_labels(self): self.function_metadata = old_function_metadata def get_program(self): + # Prune identifiers. + self.identifiers.prune(self.removed_prefixes) return PreprocessedProgram( prime=self.prime, reference_manager=self.flow_tracking.reference_manager, @@ -288,8 +320,7 @@ def add_references_from_struct_members( and location. """ args = identifier_list.identifiers if identifier_list is not None else [] - assert len(args) == len(members) - for arg, member_def in zip(args, members.values()): + for arg, member_def in safe_zip(args, members.values()): # Add a reference for the argument.
assert_no_modifier(arg) self.add_simple_reference( @@ -302,9 +333,16 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): if elm.element_type == 'struct': return - self.add_label(elm.identifier) + for decorator in elm.decorators: + if decorator.name not in self.supported_decorators: + raise PreprocessorError( + f"Unsupported decorator: '{decorator.name}'.", + location=decorator.location) + self.flow_tracking.revoke() + new_scope = self.current_scope + elm.name + if self.current_scope in self.function_metadata: outer_function_location = self.identifier_locations.get(self.current_scope) notes = [] @@ -319,10 +357,17 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): notes=notes, ) - new_scope = self.current_scope + elm.name if elm.element_type == 'func': - self.function_metadata[new_scope] = FunctionMetadata( - initial_ap_data=self.flow_tracking.get_ap_tracking()) + # Check if this function should be skipped. + if self.functions_to_compile is not None and new_scope not in self.functions_to_compile: + self.removed_prefixes.add(new_scope) + return + + self.add_function(elm) + else: + assert elm.element_type == 'namespace', f"""\ +Expected 'elm.element_type' to be a 'namespace'. Found: '{elm.element_type}'.""" + self.add_label(identifier=elm.identifier) # Add function arguments and return values and process body. args_scope = new_scope + CodeElementFunction.ARGUMENT_SCOPE @@ -538,15 +583,36 @@ def visit_CodeElementReference(self, elm: CodeElementReference): else: # Copy the type from the value. 
dst_type = val_type - if not check_cast(src_type=val_type, dest_type=dst_type, cast_type=CastType.ASSIGN): + if not check_cast( + src_type=val_type, + dest_type=dst_type, + identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN): raise PreprocessorError( f"Cannot assign an expression of type '{val_type.format()}' " f"to a reference of type '{dst_type.format()}'.", location=dst_type.location) + location = val.location + + # At this point 'val' is a simplified typeless expression and we need 'ref_expr' + # to include a cast to the correct type. + # We insert the cast at the correct location according to the outermost expression in 'val'. + if isinstance(val, ExprDeref): + # Add the cast inside the ExprDeref. For example, "[cast(ap, T*)]". + addr = get_expr_addr(val) + ref_expr: Expression = ExprDeref( + addr=ExprCast( + expr=addr, + dest_type=TypePointer(pointee=dst_type, location=location), + location=addr.location), + location=location) + else: + ref_expr = ExprCast(expr=val, dest_type=dst_type, location=location) + self.add_reference( name=name, - value=ExprCast(expr=val, dest_type=dst_type), + value=ref_expr, cairo_type=dst_type, location=elm.typed_identifier.location, ) @@ -558,54 +624,32 @@ def visit_CodeElementLocalVariable(self, elm: CodeElementLocalVariable): def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): assert_no_modifier(elm.typed_identifier) - # Build the instruction: [ap] = elm.expr; ap++. - compound_expressions_code_elements, (expr,), _ = process_compound_expressions( - [self.simplify_expr_as_felt(elm.expr)], [SimplicityLevel.OPERATION], - context=self._compound_expression_context) - for code_element in compound_expressions_code_elements: - self.visit(code_element) - - # Store the hint to avoid the check_no_hints() when invoking the reference element. 
- hint, self.next_instruction_hint = self.next_instruction_hint, None + expr, src_type = self.simplify_expr(elm.expr) + src_size = self.get_size(src_type) - if elm.typed_identifier.expr_type is not None: - expr_type = elm.typed_identifier.expr_type + if elm.typed_identifier.expr_type is None: + dest_type = src_type else: - # Copy the type from the original expression. - _, expr_type = self.simplify_expr(elm.expr) - if isinstance(expr_type, TypeStruct): - raise PreprocessorError( - "tempvar type annotation must be 'felt' or a pointer.", - location=expr_type.location) - - # Build an expression for [ap]. - deref_ap = ExprCast( - expr=ExprDeref( - addr=ExprReg(reg=Register.AP, location=elm.typed_identifier.identifier.location), - location=elm.typed_identifier.identifier.location), - dest_type=expr_type, - location=elm.typed_identifier.identifier.location) - - # Convert CodeElementTemporaryVariable to two code elements. - # Build the code element: let = [ap]. - self.visit(CodeElementReference( - typed_identifier=elm.typed_identifier, - expr=deref_ap, - )) + dest_type = self.resolve_type(elm.typed_identifier.expr_type) + if not check_cast( + src_type=src_type, dest_type=dest_type, identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN): + raise PreprocessorError( + f"Cannot assign an expression of type '{src_type.format()}' " + f"to a temporary variable of type '{dest_type.format()}'.", + location=dest_type.location) - # Restore the hint. - assert self.next_instruction_hint is None - self.next_instruction_hint = hint + dest_size = self.get_size(dest_type) + assert src_size == dest_size, 'Expecting src and dest types to have the same size.' 
- self.visit(CodeElementInstruction( - instruction=InstructionAst( - body=AssertEqInstruction( - a=deref_ap, - b=expr, - location=elm.location, - ), - inc_ap=True, - location=elm.location))) + src_exprs = self.simplified_expr_to_felt_expr_list(expr=expr, expr_type=src_type) + self.push_compound_expressions(compound_expressions=src_exprs, location=elm.location) + self.add_simple_reference( + name=self.current_scope + elm.typed_identifier.name, + reg=Register.AP, + cairo_type=dest_type, + offset=-src_size, + location=elm.typed_identifier.identifier.location) def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAssertEq): expr_a, expr_type_a = self.simplify_expr(instruction.a) @@ -615,27 +659,30 @@ def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAsse f"Cannot compare '{expr_type_a.format()}' and '{expr_type_b.format()}'.", location=instruction.location) - if not isinstance(expr_type_a, (TypeFelt, TypePointer)): - raise PreprocessorError( - f"Expected a 'felt' or a pointer type. 
Got: '{expr_type_a.format()}'.", - location=instruction.a.location) - - compound_expressions_code_elements, (expr_a, expr_b) = process_compound_assert( - self.simplify_expr_as_felt(instruction.a), - self.simplify_expr_as_felt(instruction.b), - self._compound_expression_context) - assert_eq = CodeElementInstruction( - instruction=InstructionAst( - body=AssertEqInstruction( - a=expr_a, - b=expr_b, - location=instruction.location), - inc_ap=False, - location=instruction.location)) + src_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_a, expr_type=expr_type_a) + dst_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_b, expr_type=expr_type_b) + original_ap_tracking = self.flow_tracking.get_ap_tracking() + + for src, dst in safe_zip(src_exprs, dst_exprs): + ap_diff = self.flow_tracking.get_ap_tracking() - original_ap_tracking + src = self.simplifier.visit(translate_ap(src, ap_diff)) + dst = self.simplifier.visit(translate_ap(dst, ap_diff)) + compound_expressions_code_elements, (expr_a, expr_b) = process_compound_assert( + src, + dst, + self._compound_expression_context) + assert_eq = CodeElementInstruction( + instruction=InstructionAst( + body=AssertEqInstruction( + a=expr_a, + b=expr_b, + location=instruction.location), + inc_ap=False, + location=instruction.location)) - for code_element in compound_expressions_code_elements: - self.visit(code_element) - self.visit(assert_eq) + for code_element in compound_expressions_code_elements: + self.visit(code_element) + self.visit(assert_eq) def visit_CodeElementStaticAssert(self, elm: CodeElementStaticAssert): a = self.simplify_expr_as_felt(elm.a) @@ -651,7 +698,7 @@ def optimize_expressions_for_push(self, exprs: List[Expression]) -> List[Express Example: If we need to push [ap - 2], [ap - 1], [fp] + 3, there is no need to push the first 2 - expressions, since they are already at the top of te stack. + expressions, since they are already at the top of the stack. 
""" if len(exprs) == 0: @@ -801,8 +848,7 @@ def process_implicit_arguments( implicit_args = [None] * len(implicit_args_struct.members) compound_expressions = [] - assert len(implicit_args_struct.members) == len(implicit_args) - for (member_name, member_def), implicit_arg in zip( + for (member_name, member_def), implicit_arg in safe_zip( implicit_args_struct.members.items(), implicit_args): expr: Expression if implicit_arg is not None: @@ -916,7 +962,9 @@ def check_tail_call_cast(self, src_type: CairoType, dest_type: CairoType) -> boo """ Checks if src_type can be converted to dest_type in the context of a tail call. """ - if check_cast(src_type=src_type, dest_type=dest_type, cast_type=CastType.ASSIGN): + if check_cast( + src_type=src_type, dest_type=dest_type, identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN): return True if not isinstance(src_type, TypeStruct) or not isinstance(dest_type, TypeStruct): @@ -934,6 +982,7 @@ def check_tail_call_cast(self, src_type: CairoType, dest_type: CairoType) -> boo if not check_cast( src_type=src_member.cairo_type, dest_type=dest_member.cairo_type, + identifier_manager=self.identifiers, cast_type=CastType.ASSIGN): return False @@ -1028,7 +1077,24 @@ def add_implicit_return_references( location=implicit_arg_location) def visit_CodeElementFuncCall(self, elm: CodeElementFuncCall): + # Make sure the identifier for the called function refers to a function. 
called_function = ScopedName.from_string(elm.func_call.func_ident.name) + try: + res = self.identifiers.search( + accessible_scopes=self.accessible_scopes, name=called_function) + res.assert_fully_parsed() + except IdentifierError as exc: + raise PreprocessorError(str(exc), location=elm.func_call.func_ident.location) + called_function_def = res.identifier_definition + called_function_def_type = called_function_def.identifier_type \ + if isinstance(called_function_def, FutureIdentifierDefinition) \ + else type(called_function_def) + if called_function_def_type is not FunctionDefinition: + raise PreprocessorError( + f'Expected {called_function} to be a function name. ' + f'Found: {called_function_def.TYPE}.', + location=elm.func_call.func_ident.location) + implicit_args_struct_name = called_function + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE implicit_args = ( cast(List[ExprAssignment], elm.func_call.implicit_arguments.args) @@ -1094,6 +1160,15 @@ def visit_CodeElementReturnValueReference(self, elm: CodeElementReturnValueRefer if isinstance(elm.func_call.call_inst, CallLabelInstruction): func_ident = elm.func_call.call_inst.label elif isinstance(elm.func_call, RvalueFuncCall): + # If the function name is the name of a struct, replace the + # CodeElementReturnValueReference with a regular reference. 
+ if self.try_get_struct_definition( + ScopedName.from_string(elm.func_call.func_ident.name)) is not None: + return self.visit(CodeElementReference( + typed_identifier=elm.typed_identifier, + expr=ExprFuncCall( + rvalue=elm.func_call, + location=elm.func_call.location))) call_elm = CodeElementFuncCall(func_call=elm.func_call) func_ident = elm.func_call.func_ident else: @@ -1171,6 +1246,7 @@ def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): if not check_cast( src_type=member_def.cairo_type, dest_type=cairo_type, + identifier_manager=self.identifiers, cast_type=CastType.UNPACKING): raise PreprocessorError( f"""\ @@ -1193,7 +1269,7 @@ def add_label(self, identifier: ExprIdentifier): def add_reference( self, name: ScopedName, value: Expression, cairo_type: CairoType, - location: Optional[Location]): + location: Optional[Location], require_future_definition=True): if name.path[-1] == '_': raise PreprocessorError("Reference name cannot be '_'.", location=location) @@ -1219,7 +1295,20 @@ def add_reference( self.add_name_definition( name, ReferenceDefinition(full_name=name, cairo_type=cairo_type, references=[reference]), - location=location) + location=location, require_future_definition=require_future_definition) + + def add_function(self, elm: CodeElementFunction): + name = self.current_scope + elm.name + self.add_name_definition( + name, + FunctionDefinition( # type: ignore + pc=self.current_pc, + decorators=[identifier.name for identifier in elm.decorators], + ), + location=elm.identifier.location) + + self.function_metadata[name] = FunctionMetadata( + initial_ap_data=self.flow_tracking.get_ap_tracking()) def visit_CodeElementLabel(self, elm: CodeElementLabel): self.check_no_hints('Hints before labels are not allowed.') @@ -1367,8 +1456,25 @@ def visit_BuiltinsDirective(self, directive: BuiltinsDirective): 'Redefinition of builtins directive.', location=directive.location, ) + + seen_builtins = set() + for builtin in directive.builtins: + if 
builtin in seen_builtins: + raise PreprocessorError( + f"The builtin '{builtin}' appears twice in builtins directive.", + location=directive.location, + ) + + seen_builtins.add(builtin) + self.builtins = directive.builtins + def visit_LangDirective(self, directive: LangDirective): + raise PreprocessorError( + f'Unsupported %lang directive. Are you using the correct compiler?', + location=directive.location, + ) + def simplify_expr(self, expr) -> Tuple[Expression, CairoType]: """ Simplifies the expression by resolving identifiers, type-system related reductions @@ -1379,7 +1485,7 @@ def simplify_expr(self, expr) -> Tuple[Expression, CairoType]: expr=expr, get_identifier_callback=self.get_variable, resolve_type_callback=self.resolve_type) - expr, expr_type = simplify_type_system(expr) + expr, expr_type = simplify_type_system(expr, identifiers=self.identifiers) return self.simplifier.visit(expr), self.resolve_type(expr_type) def simplify_expr_as_felt(self, expr) -> Expression: @@ -1406,47 +1512,69 @@ def simplify_expr_to_felt_expr_list( location = expr.location expr, expr_type = self.simplify_expr(expr) - if not check_cast(src_type=expr_type, dest_type=expected_type, cast_type=CastType.ASSIGN): + if not check_cast( + src_type=expr_type, dest_type=expected_type, identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN): raise PreprocessorError( f"""\ Expected expression of type '{expected_type.format()}', got '{expr_type.format()}'.""", location=location ) - if isinstance(expr_type, (TypeFelt, TypePointer)): - return [expr] + return self.simplified_expr_to_felt_expr_list(expr=expr, expr_type=expr_type) - assert isinstance(expr_type, TypeStruct), f'Unexpected type {expr_type}.' + def simplified_expr_to_felt_expr_list( + self, expr: Expression, expr_type: CairoType) -> List[Expression]: + """ + Takes a simplified expression and its type and splits it into a list of typeless expressions + that can be passed to process_compound_expressions. 
+ """ - struct_members = get_struct_definition( - expr_type.scope, identifier_manager=self.identifiers).members - addr = get_expr_addr(expr) - exprs: List[Expression] = [] - for offset, member_def in enumerate(struct_members.values()): - if not isinstance(member_def.cairo_type, (TypeFelt, TypePointer)): - raise PreprocessorError( - 'Nested structs are not supported.', - location=location) + if isinstance(expr_type, (TypeFelt, TypePointer)): + return [expr] - if offset != member_def.offset: - raise PreprocessorError( - 'Discontinuous structs are not supported.', - location=location) + # Get the list of member types. + if isinstance(expr_type, TypeTuple): + member_types = expr_type.members + elif isinstance(expr_type, TypeStruct): + struct_definition = get_struct_definition( + expr_type.scope, identifier_manager=self.identifiers) + member_types = [ + member_def.cairo_type for member_def in struct_definition.members.values()] + else: + raise PreprocessorError(f'Unexpected type {expr_type}.', location=expr_type.location) - # Call simplifier to convert (fp + offset_1) + offset_2 to fp + (offset_1 + offset_2). - exprs.append( - self.simplifier.visit( - ExprDeref( - ExprOperator( - a=addr, - op='+', - b=ExprConst(member_def.offset, location=location), + # Get the list of member expressions. + if isinstance(expr, ExprTuple): + member_exprs = [assign_expr.expr for assign_expr in expr.members.args] + else: + addr = get_expr_addr(expr) + + offset = 0 + member_exprs = [] + location = expr.location + for member_type in member_types: + # Call simplifier to convert (fp + offset_1) + offset_2 to + # fp + (offset_1 + offset_2). 
+ member_exprs.append( + self.simplifier.visit( + ExprDeref( + ExprOperator( + a=addr, + op='+', + b=ExprConst(offset, location=location), + location=location, + ), location=location, - ), - location=location, - ))) + ))) - return exprs + offset += self.get_size(member_type) + + expr_list = [] + for member_expr, member_type in zip(member_exprs, member_types): + expr_list.extend(self.simplified_expr_to_felt_expr_list( + expr=member_expr, expr_type=member_type)) + return expr_list def get_label(self, label_name: str, location: Optional[Location]) -> \ Tuple[Optional[int], Optional[ScopedName]]: @@ -1559,53 +1687,6 @@ def new_unique_id(self) -> str: return name -def preprocess_codes( - codes: Sequence[Tuple[str, str]], prime: int, - read_module: Callable[[str], Tuple[str, str]], - main_scope: ScopedName = ScopedName(), - preprocessor_cls: Optional[Type[Preprocessor]] = None) -> PreprocessedProgram: - """ - Preprocesses a list of Cairo file and returns a PreprocessedProgram instance. - codes is a list of pairs (code_string, file_name). - read_module is a callback that gets a module name ('a.b.c') and returns a pair - (file content, file name) - """ - modules = [] - for code, filename in codes: - # Function used to read files given module names. - # The root module (filename) is handled separately, for this module code is returned. - def read_file_fixed(name): - return (code, filename) if name == filename else read_module(name) - - files = collect_imports(filename, read_file=read_file_fixed) - for module_name, ast in files.items(): - # Preprocess files explicitly provided in the root scope. 
- scope = main_scope if module_name == filename else ScopedName.from_string(module_name) - modules.append(CairoModule(cairo_file=ast, module_name=scope)) - - unique_label_creator = UniqueLabelCreator() - modules = list(map(unique_label_creator.visit, modules)) - - identifier_collector = IdentifierCollector() - for module in modules: - identifier_collector.visit(module) - - struct_collector = StructCollector(identifiers=identifier_collector.identifiers) - for module in modules: - struct_collector.visit(module) - - if preprocessor_cls is None: - preprocessor_cls = Preprocessor - - preprocessor = preprocessor_cls(prime=prime, identifiers=struct_collector.identifiers) - preprocessor.identifier_locations = struct_collector.identifier_locations - for module in modules: - preprocessor.visit(module) - - preprocessor.resolve_labels() - return preprocessor.get_program() - - class PreprocessorCompoundExpressionContext(CompoundExpressionContext): def __init__(self, preprocessor: Preprocessor): self.preprocessor = preprocessor diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py index de41bc00..c4c9bc95 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py @@ -8,13 +8,15 @@ from starkware.cairo.lang.compiler.identifier_manager import IdentifierError from starkware.cairo.lang.compiler.instruction_builder import InstructionBuilderError from starkware.cairo.lang.compiler.parser import parse_type -from starkware.cairo.lang.compiler.preprocessor.preprocessor import preprocess_codes +from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import default_pass_manager +from starkware.cairo.lang.compiler.preprocessor.preprocess_codes import preprocess_codes from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( - PRIME, TEST_SCOPE, preprocess_str, 
verify_exception) + PRIME, TEST_SCOPE, preprocess_str, strip_comments_and_linebreaks, verify_exception) from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.test_utils import read_file_from_dict -from starkware.cairo.lang.compiler.type_system_visitor import ( - mark_type_resolved, simplify_type_system) +from starkware.cairo.lang.compiler.type_casts import CairoTypeError +from starkware.cairo.lang.compiler.type_system import mark_type_resolved +from starkware.cairo.lang.compiler.type_system_visitor import simplify_type_system def test_compiler(): @@ -122,6 +124,10 @@ def test_temporary_variable(): tempvar y : T* = cast(x, T*) ap += 4 [fp] = y.t +ap += 5 +tempvar z : (felt, felt) = (1, 2) +# Check the expression pushing optimization. +tempvar z : (felt, felt) = ([ap - 1], 3) """ program = preprocess_str(code=code, prime=PRIME) assert program.format() == """\ @@ -130,19 +136,20 @@ def test_temporary_variable(): [ap] = [ap + (-4)]; ap++ ap += 4 [fp] = [[ap + (-5)] + 1] +ap += 5 +[ap] = 1; ap++ +[ap] = 2; ap++ +[ap] = 3; ap++ """ def test_temporary_variable_failures(): verify_exception(""" -struct T: - member t : felt -end -tempvar x : T = 0 +tempvar x : felt = cast([ap], felt*) """, """ -file:?:?: tempvar type annotation must be 'felt' or a pointer. -tempvar x : T = 0 - ^ +file:?:?: Cannot assign an expression of type 'felt*' to a temporary variable of type 'felt'. +tempvar x : felt = cast([ap], felt*) + ^**^ """) verify_exception(""" tempvar _ = 0 @@ -412,7 +419,7 @@ def test_tail_call_failure(): """, """ file:?:?: Unknown identifier 'g'. 
return g() - ^*^ + ^ """) verify_exception(""" @@ -480,13 +487,13 @@ def test_func_args(): program = preprocess_str(code=code, prime=PRIME, main_scope=scope) reference_x = program.instructions[-1].flow_tracking_data.resolve_reference( reference_manager=program.reference_manager, name=scope + 'f.x') - assert reference_x.value.format() == 'cast([fp + (-6)], felt)' + assert reference_x.value.format() == '[cast(fp + (-6), felt*)]' reference_y = program.instructions[-1].flow_tracking_data.resolve_reference( reference_manager=program.reference_manager, name=scope + 'f.y') - assert reference_y.value.format() == f'cast([fp + (-5)], {scope}.T)' + assert reference_y.value.format() == f'[cast(fp + (-5), {scope}.T*)]' reference_z = program.instructions[-1].flow_tracking_data.resolve_reference( reference_manager=program.reference_manager, name=scope + 'f.z') - assert reference_z.value.format() == f'cast([fp + (-3)], {scope}.T*)' + assert reference_z.value.format() == f'[cast(fp + (-3), {scope}.T**)]' assert program.format() == """\ [fp + (-6)] = 1; ap++ [fp + (-5)] = 2; ap++ @@ -601,7 +608,7 @@ def test_implicit_args(): func f{x: T}() -> (): # Rebind x. 
- let x = cast([fp - 1234], T) + let x = [cast(fp - 1234, T*)] return () end @@ -620,7 +627,7 @@ def test_implicit_args(): func h(): let y = 10 - let x: T = cast([fp - 100], T) + let x: T = [cast(fp - 100, T*)] with x, y: let (res2) = g(0, 0) end @@ -909,12 +916,17 @@ def test_func_named_args_failures(): def test_function_call_by_value_args(): code = """\ +struct S: + member a : felt + member b : felt +end + struct T: member s : felt - member t : felt + member t : S end func f(x, y : T, z : T): - let t : T = cast([ap], T) + let t : T = [cast(ap, T*)] let res = f(x=2, y=z, z=t) return() end @@ -922,11 +934,13 @@ def test_function_call_by_value_args(): program = preprocess_str(code=code, prime=PRIME) assert program.format() == """\ [ap] = 2; ap++ +[ap] = [fp + (-5)]; ap++ [ap] = [fp + (-4)]; ap++ [ap] = [fp + (-3)]; ap++ -[ap] = [ap + (-3)]; ap++ -[ap] = [ap + (-3)]; ap++ -call rel -6 +[ap] = [ap + (-4)]; ap++ +[ap] = [ap + (-4)]; ap++ +[ap] = [ap + (-4)]; ap++ +call rel -8 ret """ @@ -959,27 +973,6 @@ def test_func_by_value_args_failures(test_line, expected_type, actual_type, arro """, main_scope=ScopedName()) -def test_func_by_value_nested_struct_failures(): - verify_exception(""" -struct S: - member a : felt -end - -struct T: - member s : felt - member t : S -end -func f(x, y : T): - f(1, y=y) - ret -end -""", """ -file:?:?: Nested structs are not supported. 
- f(1, y=y) - ^ -""") - - def test_func_by_value_return(): code = """\ struct T: @@ -987,7 +980,7 @@ def test_func_by_value_return(): member t : felt end func f(s : T) -> (x : T, y : T): - let t : T = cast([ap - 100], T) + let t : T = [cast(ap - 100, T*)] return(x=s, y=t) end """ @@ -1082,7 +1075,8 @@ def test_import(): """ } program = preprocess_codes( - codes=[(files['.'], '.')], prime=PRIME, read_module=read_file_from_dict(files)) + codes=[(files['.'], '.')], + pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files))) assert program.format() == """\ jmp rel 0 @@ -1123,7 +1117,8 @@ def get_full_name(name, curr_scope=''): # Preprocess program. program = preprocess_codes( - codes=[(files['.'], '.')], prime=PRIME, read_module=read_file_from_dict(files), + codes=[(files['.'], '.')], + pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)), main_scope=scope('__main__')) # Verify identifiers are resolved correctly. @@ -1186,7 +1181,7 @@ def test_import_errors(): """, files={'foo': 'const bar=0'}) verify_exception('from foo import bar', """ \ -file:?:?: Scope 'foo' does not include identifier 'bar'. +file:?:?: Cannot import 'bar' from 'foo'. from foo import bar ^*^ """, files={'foo': ''}) @@ -1377,8 +1372,8 @@ def test_redefinition_failures(): """) verify_exception(""" func f() -> (name, x, name): - [ap + name] = 1 - [ap + x] = 2 + [ap] = 1 + [ap] = 2 end """, """ file:?:?: Redefinition of 'test_scope.f.Return.name'. @@ -1412,6 +1407,13 @@ def test_directives_failures(): file:?:?: Directives must appear at the top of the file. %builtins ab cd ef ^****************^ +""") + verify_exception(""" +%lang abc +""", """ +file:?:?: Unsupported %lang directive. Are you using the correct compiler? 
+%lang abc +^*******^ """) @@ -1634,6 +1636,16 @@ def test_builtins_failures(): """) +def test_builtin_directive_duplicate_entry(): + verify_exception(""" +%builtins pedersen ecdsa pedersen +""", """ +file:?:?: The builtin 'pedersen' appears twice in builtins directive. +%builtins pedersen ecdsa pedersen +^*******************************^ +""") + + def test_references(): program = preprocess_str(code=""" call label1 @@ -1883,7 +1895,7 @@ def test_references_revoked_multiple_location(): def test_references_failures(): - verify_exception(f""" + verify_exception(""" let ref = [fp] let ref2 = ref [ref2] = [[fp]] @@ -2019,7 +2031,7 @@ def get_reference(name): reference = get_reference('main.y') assert simplify_type_system(reference.value)[1] == expected_type_y - assert reference.value.format() == f'cast([ap + 10], {scope}.main.Struct)' + assert reference.value.format() == f'[cast(ap + 10, {scope}.main.Struct*)]' assert program.format() == """\ [fp] = [ap + 12] [fp] = [[ap + 12] + 3] @@ -2035,10 +2047,10 @@ def test_typed_references_failures(): let x = fp x.a = x.a """, """ -file:?:?: Member access requires a type of the form Struct*. +file:?:?: Cannot apply dot-operator to non-struct type 'felt'. x.a = x.a ^*^ -""") +""", exc_type=CairoTypeError) verify_exception(f""" struct T: member z : felt @@ -2056,10 +2068,10 @@ def test_typed_references_failures(): member z : felt end -let x : T* = cast([ap], T) +let x : T* = [cast(ap, T*)] """, """ file:?:?: Cannot assign an expression of type 'test_scope.T' to a reference of type 'test_scope.T*'. -let x : T* = cast([ap], T) +let x : T* = [cast(ap, T*)] ^^ """) @@ -2127,10 +2139,10 @@ def test_return_value_reference_failures(): let x = call foo [x.a] = 0 """, """ -file:?:?: 'a' is not a member of 'test_scope.foo.Return'. +file:?:?: Member 'a' does not appear in definition of struct 'test_scope.foo.Return'. 
[x.a] = 0 ^*^ -""") +""", exc_type=CairoTypeError) verify_exception(f""" func foo(): ret @@ -2148,10 +2160,10 @@ def test_return_value_reference_failures(): let x : T* = cast(ap, T*) [ap] = x.a """, """ -file:?:?: 'a' is not a member of 'test_scope.T'. +file:?:?: Member 'a' does not appear in definition of struct 'test_scope.T'. [ap] = x.a ^*^ -""") +""", exc_type=CairoTypeError) def test_unpacking(): @@ -2255,18 +2267,23 @@ def test_unpacking_failures(): member a : felt member b : felt end +struct S: + member a : felt + member b : felt +end func foo() -> (a, b : T): ret end func test(): alloc_locals - let (a, local b : T) = foo() + let (a, local b : S) = foo() ret end """, """ -file:?:?: Expected a 'felt' or a pointer type. Got: 'test_scope.T'. - let (a, local b : T) = foo() - ^ +file:?:?: Expected expression of type 'test_scope.T', got 'test_scope.S'. + let (a, local b : S) = foo() + ^*********^ + """) verify_exception(f""" @@ -2368,7 +2385,7 @@ def test_bad_type_annotation(): ret end """, """ -file:?:?: Expected 'test_scope.foo' to be a struct. Found: 'label'. +file:?:?: Expected 'test_scope.foo' to be a struct. Found: 'function'. local a : foo ^*^ """) @@ -2382,7 +2399,7 @@ def test_bad_type_annotation(): ret end """, """ -file:?:?: Expected 'foo' to be a struct. Found: 'label'. +file:?:?: Expected 'foo' to be a struct. Found: 'function'. member a : foo* ^*^ """) @@ -2402,13 +2419,29 @@ def test_bad_type_annotation(): """) +def test_cast_failure(): + verify_exception(""" +struct A: +end + +func foo(a : A*): + let a = cast(5, A) + return () +end +""", """ +file:?:?: Cannot cast 'felt' to 'test_scope.A'. + let a = cast(5, A) + ^********^ +""", exc_type=CairoTypeError) + + def test_nested_function_failure(): verify_exception(""" func foo(): func bar(): return() end - return() + return () end """, """ file:?:?: Nested functions are not supported. 
@@ -2425,7 +2458,7 @@ def test_namespace_inside_function_failure(): func foo(): namespace MyNamespace: end - return() + return () end @@ -2439,22 +2472,525 @@ def test_namespace_inside_function_failure(): """) -def test_tuple_failures(): +def test_struct_assignments(): + struct_def = """\ +struct B: + member a : felt + member b : felt +end + +struct T: + member a : B + member b : felt +end +""" + + code = f"""\ +{struct_def} +func f(t : T*): + alloc_locals + local a : T = [t] + return () +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert program.format() == """\ +ap += 3 +[fp] = [[fp + (-3)]] +[fp + 1] = [[fp + (-3)] + 1] +[fp + 2] = [[fp + (-3)] + 2] +ret +""" + + code = f"""\ +{struct_def} +func copy(src : T**, dest: T**): + assert [[dest]] = [[src]] + return () +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert program.format() == """\ +[ap] = [[fp + (-3)]]; ap++ +[ap] = [[fp + (-4)]]; ap++ +[ap] = [[ap + (-1)]]; ap++ +[[ap + (-3)]] = [ap + (-1)] +[ap] = [[fp + (-3)]]; ap++ +[ap] = [[fp + (-4)]]; ap++ +[ap] = [[ap + (-1)] + 1]; ap++ +[[ap + (-3)] + 1] = [ap + (-1)] +[ap] = [[fp + (-3)]]; ap++ +[ap] = [[fp + (-4)]]; ap++ +[ap] = [[ap + (-1)] + 2]; ap++ +[[ap + (-3)] + 2] = [ap + (-1)] +ret +""" + + +def test_subscript_operator(): + code = """\ +struct T: + member x: felt + member y: felt +end + +struct S: + member a : T + member b : T + member c : T +end + +func f(s_arr : S*, table : felt**, perm : felt*): + assert s_arr[0].b.x = s_arr[1].a.y + assert (&s_arr[0].a)[2].x = (&s_arr[1].b.y)[-2] + + assert table[1][2] = 11 + + assert perm[0] = 1 + assert perm[1] = 0 + assert perm[perm[0]] = 0 + + tempvar i = 2 + tempvar j = 5 + tempvar k = -13 + assert (&(&s_arr[i].b)[j].x)[k] = s_arr[1].c.y + assert table[i][j] = 17 + + return() +end +""" + program = preprocess_str(code=code, prime=PRIME) + expected_result = """\ +[ap] = [[fp + (-5)] + 7]; ap++ # push s_arr[1].a.y +[[fp + (-5)] + 2] = [ap + (-1)] # assert s_arr[0].b.x = 
s_arr[1].a.y + +[ap] = [[fp + (-5)] + 7]; ap++ # push (&s_arr[1].b.y)[-2] +[[fp + (-5)] + 4] = [ap + (-1)] # assert (&s_arr[0].a)[2].x = (&s_arr[1].b.y)[-2] + + +[ap] = [[fp + (-4)] + 1]; ap++ # push table[1] +[ap] = 11; ap++ # push 11 +[[ap + (-2)] + 2] = [ap + (-1)] # assert table[1][2] = 11 + + +[ap] = 1; ap++ # push 1 +[[fp + (-3)]] = [ap + (-1)] # assert perm[0] = 1 + +[ap] = 0; ap++ # push 0 +[[fp + (-3)] + 1] = [ap + (-1)] # assert perm[1] = 0 + +[ap] = [[fp + (-3)]]; ap++ # push perm[0] +[ap] = [fp + (-3)] + [ap + (-1)]; ap++ # push perm + perm[0] +[ap] = 0; ap++ # push 0 +[[ap + (-2)]] = [ap + (-1)] # assert perm[perm[0]] = 0 + + +[ap] = 2; ap++ # tempvar i = 2 +[ap] = 5; ap++ # tempvar j = 5 +[ap] = -13; ap++ # tempvar k = -13 + +[ap] = [ap + (-3)] * 6; ap++ # push i * 6 +[ap] = [ap + (-1)] + 2; ap++ # push i * 6 + 2 +[ap] = [fp + (-5)] + [ap + (-1)]; ap++ # push &s_arr[i].b ( = s_arr + i * 6 + 2) +[ap] = [ap + (-5)] * 2; ap++ # push j * 2 +[ap] = [ap + (-2)] + [ap + (-1)]; ap++ # push &(&s_arr[i].b)[j].x +[ap] = [ap + (-1)] + [ap + (-6)]; ap++ # push &(&s_arr[i].b)[j].x + k +[ap] = [[fp + (-5)] + 11]; ap++ # push s_arr[1].b.y +[[ap + (-2)]] = [ap + (-1)] # assert (&(&s_arr[i].a)[j].x)[k] = s_arr[1].b.y + +[ap] = [fp + (-4)] + [ap + (-10)]; ap++ # push table + i +[ap] = [[ap + (-1)]]; ap++ # push table[i] +[ap] = [ap + (-1)] + [ap + (-11)]; ap++ # push table[i] + j +[ap] = 17; ap++ # push 17 +[[ap + (-2)]] = [ap + (-1)] # assert table[i][j] = 17 +ret +""" + assert program.format() == strip_comments_and_linebreaks(expected_result) + + +def test_dot_operator(): + code = """\ +struct R: + member x: felt + member r : R* +end + +struct S: + member x : felt + member y : felt +end + +struct T: + member x : felt + member s : S + member sp : S* +end + +func f(): + alloc_locals + let __fp__ = [fp - 100] + + local s : S + local s2 : S + local t : T + local r1 : R + + s.x = 14 + (s).y = 2 + (&t).x = 7 + assert t.s = s + + ((t).s).x = t.x * 2 + assert t.s = (t).s + 
assert (t.s).x = t.s.x + assert (&(t.s)).y = ((t).s).y + + assert t.sp = &s + assert t.sp.x = 14 + assert [t.sp].y = 2 + assert [t.sp] = s + assert [t.sp] = (&t).s + assert &((t).s) = t.sp + 5 + + assert t.sp + 2 = &s2 + assert [t.sp + 2].x = s.x + assert (t.sp + 2).y = s.y + + assert [r1.r.r].r.r.r.r = &r1 + + return() +end +""" + program = preprocess_str(code=code, prime=PRIME) + expected_result = """\ +ap += 10 # alloc_locals +[fp] = 14 # s.x = 14 +[fp + 1] = 2 # (s).y = 2 +[fp + 4] = 7 # (&t).x = 7 +[fp + 5] = [fp] # assert t.s = s (x member) +[fp + 6] = [fp + 1] # assert t.s = s (y member) + +[fp + 5] = [fp + 4] * 2 # ((t).s).x = t.x * 2 +[fp + 5] = [fp + 5] # assert t.s = (t).s (x member) +[fp + 6] = [fp + 6] # assert t.s = (t).s (y member) +[fp + 5] = [fp + 5] # assert (t.s).x = t.s.x +[fp + 6] = [fp + 6] # assert (&(t.s)).y = ((t).s).y + +[fp + 7] = [fp + (-100)] # assert t.sp = &s +[ap] = 14; ap++ # push 14 +[[fp + 7]] = [ap + (-1)] # assert t.sp.x = 14 +[ap] = 2; ap++ # push 2 +[[fp + 7] + 1] = [ap + (-1)] # assert [t.sp].y = 2 +[[fp + 7]] = [fp] # assert [t.sp] = s (x member) +[[fp + 7] + 1] = [fp + 1] # assert [t.sp] = s (y member) +[[fp + 7]] = [fp + 5] # assert [t.sp] = (&t).s (x member) +[[fp + 7] + 1] = [fp + 6] # assert [t.sp] = (&t).s (y member) +[ap] = [fp + 7] + 5; ap++ # push t.sp + 5 +[fp + (-100)] + 5 = [ap + (-1)] # assert &(t.s) = t.sp + 5 + +[ap] = [fp + (-100)] + 2; ap++ # push &s2 +[fp + 7] + 2 = [ap + (-1)] # assert t.sp + 2 = &s2 +[[fp + 7] + 2] = [fp] # assert [t.sp + 2].x = s.x +[[fp + 7] + 3] = [fp + 1] # assert (t.sp + 2).y = s.y + + # assert [r1.r.r].r.r.r.r = &r1 : +[ap] = [[fp + 9] + 1]; ap++ # push (r1.r).r ([fp + 9] = r1.r) +[ap] = [[ap + (-1)] + 1]; ap++ # push (r1.r.r).r +[ap] = [[ap + (-1)] + 1]; ap++ # push (r1.r.r.r).r +[ap] = [[ap + (-1)] + 1]; ap++ # push (r1.r.r.r.r).r +[ap] = [fp + (-100)] + 8; ap++ # push &r1 +[[ap + (-2)] + 1] = [ap + (-1)] # assert (r1.r.r.r.r.r).r = &r1 +ret +""" + assert program.format() == 
strip_comments_and_linebreaks(expected_result) + + +def test_tuple_assertions(): + code = f"""\ +func f(): + alloc_locals + local var : (felt, felt) = [cast(ap, (felt, felt)*)] + return () +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert program.format() == """\ +ap += 2 +[fp] = [ap] +[fp + 1] = [ap + 1] +ret +""" + + +def test_tuple_expression(): + code = """\ +struct A: + member x : felt + member y : felt* +end +struct B: + member x : felt + member y : A + member z : A* +end +func foo(a : A*): + alloc_locals + let a : A* = cast([fp], A*) + local b : B = cast((1, [a], a), B) + + assert (b.x, b.z, a) = (5, a, a) + return () +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert program.format() == """\ +ap += 4 +[fp] = 1 +[fp + 1] = [[fp]] +[fp + 2] = [[fp] + 1] +[fp + 3] = [fp] +[fp] = 5 +[fp + 3] = [fp] +[fp] = [fp] +ret +""" + + +def test_tuple_expression_failures(): + verify_exception(""" +struct A: + member x : felt +end +struct B: +end +let a = cast(fp, A*) +let b = cast((1, a), B) +""", """ +file:?:?: Cannot cast an expression of type '(felt, test_scope.A*)' to 'test_scope.B'. +The former has 2 members while the latter has 0 members. +let b = cast((1, a), B) + ^****^ +""", exc_type=CairoTypeError) + + verify_exception(""" +struct A: + member x : felt +end +struct B: + member a : felt + member b : felt +end +let a = cast(fp, A*) +let b = cast((a, 1), B) +""", """ +file:?:?: Cannot cast 'test_scope.A*' to 'felt'. +let b = cast((a, 1), B) + ^ +""", exc_type=CairoTypeError) + verify_exception(""" -func foo(x : (felt, felt)): +struct B: + member a : felt + member b : felt end +let b = cast([cast(ap, (felt, felt*)*)], B) """, """ -file:?:?: Tuples are not supported yet. -func foo(x : (felt, felt)): - ^**********^ +file:?:?: Cannot cast 'felt*' to 'felt'. 
+let b = cast([cast(ap, (felt, felt*)*)], B) + ^************************^ +""", exc_type=CairoTypeError) + + verify_exception(""" +struct B: +end +let b = cast([cast(ap, (felt, felt*)*)], B) +""", """ +file:?:?: Cannot cast an expression of type '(felt, felt*)' to 'test_scope.B'. +The former has 2 members while the latter has 0 members. +let b = cast([cast(ap, (felt, felt*)*)], B) + ^************************^ +""", exc_type=CairoTypeError) + verify_exception(""" +(1, 1) = 1 +""", """ +file:?:?: Expected a 'felt' or a pointer type. Got: '(felt, felt)'. +(1, 1) = 1 +^****^ """) + + verify_exception(""" +assert (1, 1) = 1 +""", """ +file:?:?: Cannot compare '(felt, felt)' and 'felt'. +assert (1, 1) = 1 +^***************^ +""") + + +def test_struct_constructor(): + code = """\ +struct A: + member x : felt + member y : felt +end +struct B: + member x : felt + member y : A + member z : A + member w : A* +end +func foo(a_ptr : A*): + alloc_locals + local b1 : B = B(x=0, y=A(1, 2), z=[a_ptr], w=a_ptr) + let a = A(x=a_ptr.x, y=0) + assert a = A(x=1, y=2) + + tempvar y: felt* = cast(1, felt*) + tempvar x: A* = cast(0, A*) + assert [x] = A(x=[y], y=[y]) + return () +end +""" + program = preprocess_str(code=code, prime=PRIME) + expected_result = """\ +ap += 6 +# Populate b1. +[fp] = 0 +[fp + 1] = 1 +[fp + 2] = 2 +[fp + 3] = [[fp + (-3)]] +[fp + 4] = [[fp + (-3)] + 1] +[fp + 5] = [fp + (-3)] + +# assert a = A(x=1, y=2) (x component). +[ap] = 1; ap++ +[[fp + (-3)]] = [ap + (-1)] + +# assert a = A(x=1, y=2) (y component). +[ap] = 2; ap++ +0 = [ap + (-1)] + +# tempvar y: felt* = cast(1, felt*). +[ap] = 1; ap++ +# tempvar x: A* = cast(0, A*). +[ap] = 0; ap++ +# assert [x] = A(x=[y], y=[y]). 
+[ap] = [[ap + (-2)]]; ap++ +[[ap + (-2)]] = [ap + (-1)] +[ap] = [[ap + (-3)]]; ap++ +[[ap + (-3)] + 1] = [ap + (-1)] +ret +""" + assert program.format() == strip_comments_and_linebreaks(expected_result) + + +def test_struct_constructor_failures(): verify_exception(""" func foo(): + ret +end + +foo(3) = foo(4) +""", """ +file:?:?: Expected 'foo' to be a struct. Found: 'function'. +foo(3) = foo(4) +^****^ +""") + + def verify_exception_for_expr(expr_str: str, expected_error: str): + verify_exception(f""" +struct T: + member x : felt + member y : felt +end + +func foo(a): alloc_locals - local x : (felt, felt) + local a : T = {expr_str} +end +""", expected_error, exc_type=CairoTypeError) + + verify_exception_for_expr('T(5, 6, 7)', """ +file:?:?: Cannot cast an expression of type '(felt, felt, felt)' to 'test_scope.T'. +The former has 3 members while the latter has 2 members. + local a : T = T(5, 6, 7) + ^********^ +""") + + verify_exception_for_expr('&T(5, 6)', """ +file:?:?: Expression has no address. + local a : T = &T(5, 6) + ^*****^ +""") + + verify_exception_for_expr('T(5, 6).x', """ +file:?:?: Accessing struct members for r-value structs is not supported yet. + local a : T = T(5, 6).x + ^*******^ +""") + + verify_exception_for_expr('T{a}(5, 6)', """ +file:?:?: Implicit arguments cannot be used with struct constructors. + local a : T = T{a}(5, 6) + ^ +""") + + +def test_unsupported_decorator(): + verify_exception(""" +@external +func foo(): + return() end """, """ -file:?:?: Tuples are not supported yet. - local x : (felt, felt) - ^**********^ +file:?:?: Unsupported decorator: 'external'. 
+@external +^*******^ """) + + +def test_skipped_functions(): + files = {'module': """ +func func0(): + tempvar x = 0 + return () +end +func func1(): + tempvar x = 1 + return () +end +func func2(): + tempvar x = 2 + return func1() +end +""", '.': """ +from module import func2 +func2() +"""} + program = preprocess_codes( + codes=[(files['.'], '.')], + pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files))) + assert program.format() == """\ +[ap] = 1; ap++ +ret +[ap] = 2; ap++ +call rel -5 +ret +call rel -5 +""" + program = preprocess_codes( + codes=[(files['.'], '.')], + pass_manager=default_pass_manager( + prime=PRIME, + read_module=read_file_from_dict(files), + opt_unused_functions=False)) + assert program.format() == """\ +[ap] = 0; ap++ +ret +[ap] = 1; ap++ +ret +[ap] = 2; ap++ +call rel -5 +ret +call rel -5 +""" diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py index b395975e..3ca17a83 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py @@ -3,8 +3,11 @@ import pytest +from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import default_pass_manager +from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManager +from starkware.cairo.lang.compiler.preprocessor.preprocess_codes import preprocess_codes from starkware.cairo.lang.compiler.preprocessor.preprocessor import ( - PreprocessedProgram, Preprocessor, preprocess_codes) + PreprocessedProgram, Preprocessor) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.test_utils import read_file_from_dict @@ -15,31 +18,56 @@ TEST_SCOPE = ScopedName.from_string('test_scope') +def 
strip_comments_and_linebreaks(program: str): + """ + Removes all comments and empty lines from the given program. + """ + program = re.sub(r'\s*#.*\n', '\n', program) + return re.sub('\n+', '\n', program) + + def default_read_module(module_name: str): raise Exception( f'Error: trying to read module {module_name}, no reading algorithm provided.') def preprocess_str( - code: str, prime: int, main_scope: Optional[ScopedName] = None) -> PreprocessedProgram: + code: str, prime: int, main_scope: Optional[ScopedName] = None, + preprocessor_cls: Optional[Type[Preprocessor]] = None) -> PreprocessedProgram: + return preprocess_str_ex( + code=code, + pass_manager=default_pass_manager( + prime=prime, read_module=default_read_module, preprocessor_cls=preprocessor_cls), + main_scope=main_scope) + + +def preprocess_str_ex( + code: str, pass_manager: PassManager, + main_scope: Optional[ScopedName] = None) -> PreprocessedProgram: if main_scope is None: main_scope = TEST_SCOPE return preprocess_codes( - [(code, '')], prime, read_module=default_read_module, main_scope=main_scope) + [(code, '')], + pass_manager=pass_manager, + main_scope=main_scope) def verify_exception( code: str, error: str, files: Dict[str, str] = {}, main_scope: Optional[ScopedName] = None, - exc_type=PreprocessorError, preprocessor_cls: Optional[Type[Preprocessor]] = None): + exc_type=PreprocessorError, pass_manager: Optional[PassManager] = None): """ Verifies that compiling the code results in the given error. """ if main_scope is None: main_scope = TEST_SCOPE + if pass_manager is None: + pass_manager = default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)) + with pytest.raises(exc_type) as e: preprocess_codes( - [(code, '')], prime=PRIME, read_module=read_file_from_dict(files), - main_scope=main_scope, preprocessor_cls=preprocessor_cls) + codes=[(code, '')], + pass_manager=pass_manager, + main_scope=main_scope) # Remove line and column information from the error using a regular expression. 
assert re.sub(':[0-9]+:[0-9]+', 'file:?:?', str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py index ecd27a01..18cf361c 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py @@ -20,6 +20,10 @@ def test_unique_label_creator(): end end end +func main(): + B.foo(1, 2) + ret +end """, prime=PRIME) assert program.format() == """\ jmp rel 10 if [fp + (-4)] != 0 @@ -33,4 +37,8 @@ def test_unique_label_creator(): ret [ap] = 3; ap++ ret +[ap] = 1; ap++ +[ap] = 2; ap++ +call rel -22 +ret """ diff --git a/src/starkware/cairo/lang/compiler/program.py b/src/starkware/cairo/lang/compiler/program.py index 7ab72d52..5b3e8bcd 100644 --- a/src/starkware/cairo/lang/compiler/program.py +++ b/src/starkware/cairo/lang/compiler/program.py @@ -2,9 +2,8 @@ import string from abc import ABC, abstractmethod from dataclasses import field -from typing import ClassVar, Dict, List, Optional, Type, Union +from typing import Dict, List, Optional, Type, Union -import marshmallow import marshmallow.fields as mfields import marshmallow_dataclass @@ -17,6 +16,7 @@ from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingDataActual, ReferenceManager from starkware.cairo.lang.compiler.references import Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName, ScopedNameAsStr +from starkware.starkware_utils.validated_dataclass import SerializableMarshmallowDataclass @dataclasses.dataclass @@ -67,8 +67,8 @@ def run_validity_checks(self): 'Invalid main() address.' 
-@marshmallow_dataclass.dataclass -class Program(ProgramBase): +@marshmallow_dataclass.dataclass(repr=False) +class Program(ProgramBase, SerializableMarshmallowDataclass): prime: int data: List[int] hints: Dict[int, CairoHint] @@ -79,7 +79,6 @@ class Program(ProgramBase): # Holds all the allocated references in the program. reference_manager: ReferenceManager debug_info: Optional[DebugInfo] = None - Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema def stripped(self) -> StrippedProgram: assert self.main is not None @@ -91,11 +90,15 @@ def stripped(self) -> StrippedProgram: ) def get_identifier( - self, name: Union[str, ScopedName], expected_type: Type[IdentifierDefinition]): + self, name: Union[str, ScopedName], expected_type: Type[IdentifierDefinition], + full_name_lookup: Optional[bool] = None): scoped_name = name if isinstance(name, ScopedName) else ScopedName.from_string(name) - result = self.identifiers.search( - accessible_scopes=[self.main_scope], - name=scoped_name) + if full_name_lookup is True: + result = self.identifiers.root.get(scoped_name) + else: + result = self.identifiers.search( + accessible_scopes=[self.main_scope], + name=scoped_name) result.assert_fully_parsed() identifier_definition = result.identifier_definition assert isinstance(identifier_definition, expected_type), ( @@ -103,11 +106,13 @@ def get_identifier( f'found {identifier_definition.TYPE}.') # type: ignore return identifier_definition - def get_label(self, name: Union[str, ScopedName]): - return self.get_identifier(name, LabelDefinition).pc + def get_label(self, name: Union[str, ScopedName], full_name_lookup: Optional[bool] = None): + return self.get_identifier( + name=name, expected_type=LabelDefinition, full_name_lookup=full_name_lookup).pc - def get_const(self, name: Union[str, ScopedName]): - return self.get_identifier(name, ConstDefinition).value + def get_const(self, name: Union[str, ScopedName], full_name_lookup: Optional[bool] = None): + return 
self.get_identifier( + name=name, expected_type=ConstDefinition, full_name_lookup=full_name_lookup).value def get_reference_binds(self, name: Union[str, ScopedName]) -> List[Reference]: """ diff --git a/src/starkware/cairo/lang/compiler/references.py b/src/starkware/cairo/lang/compiler/references.py index f1475eb6..d97b8f02 100644 --- a/src/starkware/cairo/lang/compiler/references.py +++ b/src/starkware/cairo/lang/compiler/references.py @@ -4,7 +4,7 @@ import marshmallow import marshmallow_dataclass -from starkware.cairo.lang.compiler.ast.cairo_types import CairoType +from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypePointer from starkware.cairo.lang.compiler.ast.expr import ( ExprCast, ExprConst, ExprDeref, Expression, ExprOperator, ExprReg) from starkware.cairo.lang.compiler.error_handling import Location @@ -23,19 +23,19 @@ def __init__(self, message): def create_simple_ref_expr( reg: Register, offset: int, cairo_type: CairoType, - location: Optional[Location]) -> ExprCast: + location: Optional[Location]) -> Expression: """ - Creates an expression of the form 'cast([reg + offset], cairo_type)'. + Creates an expression of the form '[cast(reg + offset, cairo_type*)]'. 
""" - return ExprCast( - ExprDeref( - addr=ExprOperator( + return ExprDeref( + addr=ExprCast( + expr=ExprOperator( a=ExprReg(reg=reg, location=location), op='+', b=ExprConst(val=offset, location=location), location=location), + dest_type=TypePointer(pointee=cairo_type, location=location), location=location), - dest_type=cairo_type, location=location) diff --git a/src/starkware/cairo/lang/compiler/resolve_search_result.py b/src/starkware/cairo/lang/compiler/resolve_search_result.py new file mode 100644 index 00000000..0387a5de --- /dev/null +++ b/src/starkware/cairo/lang/compiler/resolve_search_result.py @@ -0,0 +1,45 @@ +from starkware.cairo.lang.compiler.constants import SIZE_CONSTANT +from starkware.cairo.lang.compiler.identifier_definition import ( + ConstDefinition, DefinitionError, IdentifierDefinition, ReferenceDefinition, StructDefinition) +from starkware.cairo.lang.compiler.identifier_manager import ( + IdentifierError, IdentifierManager, IdentifierSearchResult) +from starkware.cairo.lang.compiler.offset_reference import OffsetReferenceDefinition + + +def resolve_search_result( + search_result: IdentifierSearchResult, + identifiers: IdentifierManager) -> IdentifierDefinition: + """ + Returns a fully parsed identifier definition for the given identifier search result. + If search_result contains a reference with non_parsed data, returns an instance of + OffsetReferenceDefinition. 
+ """ + identifier_definition = search_result.identifier_definition + + if len(search_result.non_parsed) == 0: + return identifier_definition + + if isinstance(identifier_definition, StructDefinition): + if search_result.non_parsed == SIZE_CONSTANT: + return ConstDefinition(value=identifier_definition.size) + + member_def = identifier_definition.members.get(search_result.non_parsed.path[0]) + struct_name = identifier_definition.full_name + if member_def is None: + raise DefinitionError( + f"'{search_result.non_parsed}' is not a member of '{struct_name}'.") + + if len(search_result.non_parsed) > 1: + raise IdentifierError( + f"Unexpected '.' after '{struct_name + search_result.non_parsed.path[0]}' which is " + f'{member_def.TYPE}.') + + identifier_definition = member_def + elif isinstance(identifier_definition, ReferenceDefinition): + identifier_definition = OffsetReferenceDefinition( + parent=identifier_definition, + member_path=search_result.non_parsed) + else: + search_result.assert_fully_parsed() + + return identifier_definition diff --git a/src/starkware/cairo/lang/compiler/resolve_search_result_test.py b/src/starkware/cairo/lang/compiler/resolve_search_result_test.py new file mode 100644 index 00000000..a5fb1d21 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/resolve_search_result_test.py @@ -0,0 +1,36 @@ +import pytest + +from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt +from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition +from starkware.cairo.lang.compiler.identifier_manager import ( + IdentifierError, IdentifierManager, IdentifierSearchResult) +from starkware.cairo.lang.compiler.resolve_search_result import resolve_search_result +from starkware.cairo.lang.compiler.scoped_name import ScopedName + +scope = ScopedName.from_string + + +def test_resolve_search_result(): + struct_def = StructDefinition( + full_name=scope('T'), + members={ + 'a': MemberDefinition(offset=0, 
cairo_type=TypeFelt()), + + 'b': MemberDefinition(offset=1, cairo_type=TypeFelt()), + }, + size=2, + ) + + identifier_dict = { + struct_def.full_name: struct_def, + } + + identifier = IdentifierManager.from_dict(identifier_dict) + + with pytest.raises(IdentifierError, match="Unexpected '.' after 'T.a' which is member"): + resolve_search_result( + search_result=IdentifierSearchResult( + identifier_definition=struct_def, + canonical_name=struct_def.full_name, + non_parsed=scope('a.z')), + identifiers=identifier) diff --git a/src/starkware/cairo/lang/compiler/substitute_identifiers.py b/src/starkware/cairo/lang/compiler/substitute_identifiers.py index fb22c58f..20c81d4b 100644 --- a/src/starkware/cairo/lang/compiler/substitute_identifiers.py +++ b/src/starkware/cairo/lang/compiler/substitute_identifiers.py @@ -1,9 +1,13 @@ from typing import Callable, Optional, Union -from starkware.cairo.lang.compiler.ast.cairo_types import CairoType +from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeStruct from starkware.cairo.lang.compiler.ast.expr import ( - ExprCast, ExprConst, Expression, ExprFutureLabel, ExprIdentifier) + ExprCast, ExprConst, Expression, ExprFutureLabel, ExprIdentifier, ExprTuple) +from starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall +from starkware.cairo.lang.compiler.ast.rvalue import RvalueFuncCall from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.compiler.type_casts import CairoTypeError GetIdentifierCallback = Callable[[ExprIdentifier], Union[int, Expression]] ResolveTypeCallback = Optional[Callable[[CairoType], CairoType]] @@ -34,6 +38,35 @@ def visit_ExprCast(self, expr: ExprCast): notes=expr.notes, location=expr.location) + def visit_RvalueFuncCall(self, rvalue: RvalueFuncCall): + # Same as super().RvalueFuncCall, except that we don't visit rvalue.func_ident. 
+ # The reason is that function names do not constitute expressions in Cairo, + # and visiting them in this visitor results in an error. + return RvalueFuncCall( + func_ident=rvalue.func_ident, + arguments=self.visit_ArgList(rvalue.arguments), + implicit_arguments=None if rvalue.implicit_arguments is None else self.visit_ArgList( + rvalue.implicit_arguments), + location=rvalue.location) + + def visit_ExprFuncCall(self, expr: ExprFuncCall): + # Convert ExprFuncCall to ExprCast. + rvalue = expr.rvalue + if rvalue.implicit_arguments is not None: + raise CairoTypeError( + 'Implicit arguments cannot be used with struct constructors.', + location=rvalue.implicit_arguments.location) + + struct_type = self.resolve_type_callback(TypeStruct( + scope=ScopedName.from_string(rvalue.func_ident.name), + is_fully_resolved=False, + location=expr.location)) + + return self.visit(ExprCast( + expr=ExprTuple(rvalue.arguments, location=expr.location), + dest_type=struct_type, + location=expr.location)) + def visit_ExprFutureLabel(self, expr: ExprFutureLabel): return self.visit(expr.identifier) diff --git a/src/starkware/cairo/lang/compiler/type_casts.py b/src/starkware/cairo/lang/compiler/type_casts.py index 8bb83c23..cb8b6c2e 100644 --- a/src/starkware/cairo/lang/compiler/type_casts.py +++ b/src/starkware/cairo/lang/compiler/type_casts.py @@ -1,9 +1,12 @@ -from typing import Optional +import itertools +from typing import Iterable, Optional, cast from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, CastType, TypeFelt, TypePointer, TypeStruct) -from starkware.cairo.lang.compiler.ast.expr import ExprDeref, Expression + CairoType, CastType, TypeFelt, TypePointer, TypeStruct, TypeTuple) +from starkware.cairo.lang.compiler.ast.expr import Expression, ExprTuple from starkware.cairo.lang.compiler.error_handling import LocationError +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.identifier_utils import
get_struct_definition FELT_STAR = TypePointer(pointee=TypeFelt()) @@ -13,8 +16,8 @@ class CairoTypeError(LocationError): def check_cast( - src_type: CairoType, dest_type: CairoType, expr: Optional[Expression] = None, - cast_type: CastType = CastType.EXPLICIT) -> bool: + src_type: CairoType, dest_type: CairoType, identifier_manager: IdentifierManager, + expr: Optional[Expression] = None, cast_type: CastType = CastType.EXPLICIT) -> bool: """ Returns true if the given expression can be casted from src_type to dest_type according to the given 'cast_type'. @@ -46,14 +49,37 @@ def check_cast( return False # CastType.EXPLICIT checks: + assert expr is not None, f'CastType.EXPLICIT requires expr != None.' - # Allow casting to T if the expression is a dereference expression (that is, of the form [...]). - if isinstance(dest_type, TypeStruct): - assert expr is not None, 'expr must be specified with CastType.EXPLICIT.' - if not isinstance(expr, ExprDeref): + if isinstance(src_type, TypeTuple) and isinstance(dest_type, TypeStruct): + struct_def = get_struct_definition( + struct_name=dest_type.resolved_scope, identifier_manager=identifier_manager) + + n_src_members = len(src_type.members) + n_dest_members = len(struct_def.members) + if n_src_members != n_dest_members: raise CairoTypeError( - f"Cannot cast to '{dest_type.format()}' since the expression has no address.", + f"""\ +Cannot cast an expression of type '{src_type.format()}' to '{dest_type.format()}'. 
+The former has {n_src_members} members while the latter has {n_dest_members} members.""", location=expr.location) + + src_exprs = cast( + Iterable, expr.members.args if isinstance(expr, ExprTuple) else + itertools.repeat(expr)) + + for (src_expr, src_member_type, dest_member) in zip( + src_exprs, src_type.members, struct_def.members.values()): + dest_member_type = dest_member.cairo_type + if not check_cast( + src_type=src_member_type, dest_type=dest_member_type, + identifier_manager=identifier_manager, expr=src_expr, + cast_type=CastType.ASSIGN): + + raise CairoTypeError( + f"Cannot cast '{src_member_type.format()}' to '{dest_member_type.format()}'.", + location=src_expr.location) + return True assert cast_type is CastType.EXPLICIT, f'Unsupported cast type: {cast_type}.' diff --git a/src/starkware/cairo/lang/compiler/type_casts_test.py b/src/starkware/cairo/lang/compiler/type_casts_test.py index bd968239..3947dafe 100644 --- a/src/starkware/cairo/lang/compiler/type_casts_test.py +++ b/src/starkware/cairo/lang/compiler/type_casts_test.py @@ -1,6 +1,7 @@ import pytest from starkware.cairo.lang.compiler.ast.cairo_types import CastType +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.parser import parse_expr, parse_type from starkware.cairo.lang.compiler.type_casts import check_cast @@ -11,18 +12,21 @@ ['felt*', 'felt', True, True, False], ['felt*', 'T*', True, True, False], ['T*', 'felt*', True, True, True], - ['felt*', 'T', True, False, False], + ['felt*', 'T', False, False, False], ['T', 'felt*', False, False, False], ['felt', '(felt,felt)', False, False, False], ]) def test_type_casts( src: str, dest: str, explicit_cast: bool, unpacking_cast: bool, assign_cast: bool): + identifier_manager = IdentifierManager() src_type = parse_type(src) dest_type = parse_type(dest) expr = parse_expr('[ap]') actual_results = [ - check_cast(src_type=src_type, dest_type=dest_type, cast_type=cast_type, expr=expr) + 
check_cast( + src_type=src_type, dest_type=dest_type, identifier_manager=identifier_manager, + expr=expr, cast_type=cast_type) for cast_type in [CastType.EXPLICIT, CastType.UNPACKING, CastType.ASSIGN]] expected_results = [explicit_cast, unpacking_cast, assign_cast] assert actual_results == expected_results diff --git a/src/starkware/cairo/lang/compiler/type_system.py b/src/starkware/cairo/lang/compiler/type_system.py new file mode 100644 index 00000000..5cd88302 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/type_system.py @@ -0,0 +1,58 @@ +import dataclasses + +from starkware.cairo.lang.compiler.ast.cairo_types import ( + CairoType, TypeFelt, TypePointer, TypeStruct, TypeTuple) +from starkware.cairo.lang.compiler.ast.expr import ExprCast, Expression +from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer + + +def mark_type_resolved(cairo_type: CairoType) -> CairoType: + """ + Marks the given type as resolved (struct names are absolute). + This function can be used after parsing a string which is known to contain resolved types. + """ + if isinstance(cairo_type, TypeFelt): + return cairo_type + elif isinstance(cairo_type, TypePointer): + return dataclasses.replace(cairo_type, pointee=mark_type_resolved(cairo_type.pointee)) + elif isinstance(cairo_type, TypeStruct): + if cairo_type.is_fully_resolved: + return cairo_type + return dataclasses.replace( + cairo_type, + is_fully_resolved=True) + elif isinstance(cairo_type, TypeTuple): + return dataclasses.replace( + cairo_type, + members=[mark_type_resolved(member) for member in cairo_type.members]) + else: + raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') + + +def is_type_resolved(cairo_type: CairoType) -> bool: + """ + Returns true if the type is resolved (struct names are absolute). 
+ """ + if isinstance(cairo_type, TypeFelt): + return True + elif isinstance(cairo_type, TypePointer): + return is_type_resolved(cairo_type.pointee) + elif isinstance(cairo_type, TypeStruct): + return cairo_type.is_fully_resolved + elif isinstance(cairo_type, TypeTuple): + return all(map(is_type_resolved, cairo_type.members)) + else: + raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') + + +class MarkResolved(ExpressionTransformer): + def visit_ExprCast(self, expr: ExprCast): + return dataclasses.replace( + expr, expr=self.visit(expr.expr), dest_type=mark_type_resolved(expr.dest_type)) + + +def mark_types_in_expr_resolved(expr: Expression): + """ + Same as mark_type_resolved() except that it operates on all types within an expression. + """ + return MarkResolved().visit(expr) diff --git a/src/starkware/cairo/lang/compiler/type_system_visitor.py b/src/starkware/cairo/lang/compiler/type_system_visitor.py index 54aedfa0..04b2e74c 100644 --- a/src/starkware/cairo/lang/compiler/type_system_visitor.py +++ b/src/starkware/cairo/lang/compiler/type_system_visitor.py @@ -1,13 +1,18 @@ import dataclasses -from typing import Tuple, cast +from typing import Optional, Tuple from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, TypeFelt, TypePointer, TypeStruct, TypeTuple) from starkware.cairo.lang.compiler.ast.expr import ( - ExprAddressOf, ExprAssignment, ExprCast, ExprConst, ExprDeref, Expression, ExprFutureLabel, - ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, ExprPyConst, ExprReg, ExprTuple) -from starkware.cairo.lang.compiler.ast.visitor import Visitor -from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer + ExprAddressOf, ExprCast, ExprConst, ExprDeref, ExprDot, Expression, ExprFutureLabel, + ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, ExprPyConst, ExprReg, ExprSubscript, + ExprTuple) +from starkware.cairo.lang.compiler.error_handling import Location +from 
starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition +from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( + IdentifierAwareVisitor) from starkware.cairo.lang.compiler.type_casts import CairoTypeError, check_cast @@ -17,28 +22,32 @@ def get_expr_addr(expr: Expression): return expr.addr -class TypeSystemVisitor(Visitor): +class TypeSystemVisitor(IdentifierAwareVisitor): """ Helper class for simplify_type_system(). """ - def visit_ExprConst(self, expr: ExprConst) -> Tuple[Expression, CairoType]: + def __init__(self, identifiers: Optional[IdentifierManager] = None): + super().__init__(identifiers) + self.identifiers_initalized = identifiers is not None + + def visit_ExprConst(self, expr: ExprConst) -> Tuple[ExprConst, TypeFelt]: return expr, TypeFelt(location=expr.location) - def visit_ExprPyConst(self, expr: ExprPyConst) -> Tuple[Expression, CairoType]: + def visit_ExprPyConst(self, expr: ExprPyConst) -> Tuple[ExprPyConst, TypeFelt]: return expr, TypeFelt(location=expr.location) - def visit_ExprFutureLabel(self, expr: ExprFutureLabel) -> Tuple[Expression, CairoType]: + def visit_ExprFutureLabel(self, expr: ExprFutureLabel) -> Tuple[ExprFutureLabel, TypeFelt]: return expr, TypeFelt(location=expr.identifier.location) def visit_ExprIdentifier(self, expr: ExprIdentifier) -> Tuple[Expression, CairoType]: raise CairoTypeError( f'Unexpected unresolved identifier {expr.format()}.', location=expr.location) - def visit_ExprReg(self, expr: ExprReg) -> Tuple[Expression, CairoType]: + def visit_ExprReg(self, expr: ExprReg) -> Tuple[ExprReg, TypeFelt]: return expr, TypeFelt(location=expr.location) - def visit_ExprOperator(self, expr: ExprOperator) -> Tuple[Expression, CairoType]: + def visit_ExprOperator(self, expr: ExprOperator) -> Tuple[ExprOperator, CairoType]: 
a_expr, a_type = self.visit(expr.a) b_expr, b_type = self.visit(expr.b) op = expr.op @@ -59,11 +68,11 @@ def visit_ExprOperator(self, expr: ExprOperator) -> Tuple[Expression, CairoType] location=expr.location) return dataclasses.replace(expr, a=a_expr, b=b_expr), result_type - def visit_ExprAddressOf(self, expr: ExprAddressOf) -> Tuple[Expression, CairoType]: + def visit_ExprAddressOf(self, expr: ExprAddressOf) -> Tuple[Expression, TypePointer]: inner_expr, inner_type = self.visit(expr.expr) return get_expr_addr(inner_expr), TypePointer(pointee=inner_type) - def visit_ExprNeg(self, expr: ExprNeg) -> Tuple[Expression, CairoType]: + def visit_ExprNeg(self, expr: ExprNeg) -> Tuple[ExprNeg, TypeFelt]: inner_expr, inner_type = self.visit(expr.val) if not isinstance(inner_type, TypeFelt): raise CairoTypeError( @@ -75,7 +84,7 @@ def visit_ExprNeg(self, expr: ExprNeg) -> Tuple[Expression, CairoType]: def visit_ExprParentheses(self, expr: ExprParentheses) -> Tuple[Expression, CairoType]: return self.visit(expr.val) - def visit_ExprDeref(self, expr: ExprDeref) -> Tuple[Expression, CairoType]: + def visit_ExprDeref(self, expr: ExprDeref) -> Tuple[ExprDeref, CairoType]: addr_expr, addr_type = self.visit(expr.addr) if isinstance(addr_type, TypeFelt): return dataclasses.replace(expr, addr=addr_expr), TypeFelt(location=expr.location) @@ -86,23 +95,152 @@ def visit_ExprDeref(self, expr: ExprDeref) -> Tuple[Expression, CairoType]: f"Cannot dereference type '{addr_type.format()}'.", location=expr.location) - def visit_ExprCast(self, expr: ExprCast) -> Tuple[Expression, CairoType]: + @staticmethod + def verify_offset_is_felt(offset_type: CairoType, offset_location: Location): + if not isinstance(offset_type, TypeFelt): + raise CairoTypeError( + 'Cannot apply subscript-operator with offset of non-felt type ' + f"'{offset_type.format()}'.", location=offset_location) + + def visit_ExprSubscript(self, expr: ExprSubscript) -> Tuple[Expression, CairoType]: + inner_expr, inner_type = 
self.visit(expr.expr) + offset_expr, offset_type = self.visit(expr.offset) + + if isinstance(inner_type, TypeTuple): + self.verify_offset_is_felt(offset_type, offset_expr.location) + offset_expr = ExpressionSimplifier().visit(offset_expr) + if not isinstance(offset_expr, ExprConst): + raise CairoTypeError( + 'Subscript-operator for tuples supports only constant offsets, found ' + f"'{type(offset_expr).__name__}'.", + location=offset_expr.location) + offset_value = offset_expr.val + + tuple_len = len(inner_type.members) + if not 0 <= offset_value < tuple_len: + raise CairoTypeError( + f'Tuple index {offset_value} is out of range [0, {tuple_len}).', + location=expr.location) + + item_type = inner_type.members[offset_value] + + if isinstance(inner_expr, ExprTuple): + assert len(inner_expr.members.args) == tuple_len + return ( + # Take the inner item, but keep the original expression's location. + dataclasses.replace( + inner_expr.members.args[offset_value].expr, location=expr.location), + item_type) + elif isinstance(inner_expr, ExprDeref): + # Handles pointers cast as tuples*, e.g. `[cast(ap, (felt, felt)*)][0]`. + addr = inner_expr.addr + offset_in_felts = ExprConst( + val=sum(map(self.get_size, inner_type.members[:offset_value])), + location=offset_expr.location) + addr_with_offset = ExprOperator( + a=addr, op='+', b=offset_in_felts, location=expr.location) + return ExprDeref(addr=addr_with_offset, location=expr.location), item_type + else: + raise CairoTypeError( + 'Unexpected expression typed as TypeTuple. Expected ExprTuple or ExprDeref, ' + f"found '{type(inner_expr).__name__}'.", + location=expr.location) + elif isinstance(inner_type, TypePointer): + self.verify_offset_is_felt(offset_type, offset_expr.location) + try: + # If pointed type is struct, get_size could throw IdentifierErrors. We catch and + # convert them to CairoTypeErrors.
+ element_size = self.get_size(inner_type.pointee) + except Exception as exc: + raise CairoTypeError(str(exc), location=expr.location) + + element_size_expr = ExprConst(val=element_size, location=expr.location) + modified_offset_expr = ExprOperator( + a=offset_expr, op='*', b=element_size_expr, location=expr.location) + simplified_expr = ExprDeref( + addr=ExprOperator( + a=inner_expr, op='+', b=modified_offset_expr, location=expr.location), + location=expr.location) + + return simplified_expr, inner_type.pointee + else: + raise CairoTypeError( + 'Cannot apply subscript-operator to non-pointer, non-tuple type ' + f"'{inner_type.format()}'.", + location=expr.location) + + def verify_identifier_manager_initialized(self, location: Optional[Location]): + if self.identifiers_initalized: + return + raise CairoTypeError( + 'Identifiers must be initialized for type-simplification of dot-operator ' + 'expressions.', location=location) + + def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]: + self.verify_identifier_manager_initialized(location=expr.location) + inner_expr, inner_type = self.visit(expr.expr) + if isinstance(inner_type, TypePointer): + if not isinstance(inner_type.pointee, TypeStruct): + raise CairoTypeError( + f'Cannot apply dot-operator to pointer-to-non-struct type ' + f"'{inner_type.format()}'.", location=expr.location) + # Allow for . as ->, once. + inner_type = inner_type.pointee + elif isinstance(inner_type, TypeStruct): + if isinstance(inner_expr, ExprTuple): + raise CairoTypeError( + 'Accessing struct members for r-value structs is not supported yet.', + location=expr.location) + # Get the address, to evaluate . as ->. 
+ inner_expr = get_expr_addr(inner_expr) + else: + raise CairoTypeError( + f"Cannot apply dot-operator to non-struct type '{inner_type.format()}'.", + location=expr.location) + + try: + struct_def = get_struct_definition( + struct_name=inner_type.resolved_scope, identifier_manager=self.identifiers) + except Exception as exc: + raise CairoTypeError(str(exc), location=expr.location) + + if expr.member.name not in struct_def.members: + raise CairoTypeError( + f"Member '{expr.member.name}' does not appear in definition of struct " + f"'{inner_type.format()}'.", location=expr.location) + member_definition = struct_def.members[expr.member.name] + member_type = member_definition.cairo_type + member_offset = member_definition.offset + + if member_offset == 0: + simplified_expr = ExprDeref(addr=inner_expr, location=expr.location) + else: + mem_offset_expr = ExprConst(val=member_offset, location=expr.location) + simplified_expr = ExprDeref( + addr=ExprOperator(a=inner_expr, op='+', b=mem_offset_expr, location=expr.location), + location=expr.location) + + return simplified_expr, member_type + + def visit_ExprCast(self, expr: ExprCast) -> Tuple[Expression, CairoType]: + inner_expr, src_type = self.visit(expr.expr) dest_type = expr.dest_type if not check_cast( - src_type=inner_type, dest_type=dest_type, expr=inner_expr, - cast_type=expr.cast_type): + src_type=src_type, dest_type=dest_type, identifier_manager=self.identifiers, + expr=inner_expr, cast_type=expr.cast_type): raise CairoTypeError( - f"Cannot cast '{inner_type.format()}' to '{dest_type.format()}'.", + f"Cannot cast '{src_type.format()}' to '{dest_type.format()}'.", location=expr.location) - # Remove the cast() from the expression. - return inner_expr, dest_type + # Remove the cast() from the expression, but keep its original location. 
+ return dataclasses.replace(inner_expr, location=expr.location), dest_type - def visit_ExprTuple(self, expr: ExprTuple) -> Tuple[Expression, CairoType]: + def visit_ExprTuple(self, expr: ExprTuple) -> Tuple[ExprTuple, TypeTuple]: args = expr.members.args - member_expr_types = [self.visit(cast(ExprAssignment, arg).expr) for arg in args] + # Call visit on each member to obtain a list of the form (expr, type). + member_expr_types = [self.visit(arg.expr) for arg in args] result_members = [ dataclasses.replace(arg, expr=expr) for arg, (expr, _) in zip(args, member_expr_types)] result_expr = dataclasses.replace( @@ -113,56 +251,19 @@ def visit_ExprTuple(self, expr: ExprTuple) -> Tuple[Expression, CairoType]: return result_expr, cairo_type -def simplify_type_system(expr: Expression) -> Tuple[Expression, CairoType]: +def simplify_type_system( + expr: Expression, + identifiers: Optional[IdentifierManager] = None) -> Tuple[Expression, CairoType]: """ Given an expression returns a type-simplified expression and its Cairo type. - This includes, checking types in operations and removing casts. - For example, for the input [cast(fp, T*)] the result will be ([fp], T). - """ - return TypeSystemVisitor().visit(expr) - - -def mark_type_resolved(cairo_type: CairoType) -> CairoType: - """ - Marks the given type as resolved (struct names are absolute). - This function can be used after parsing a string which is known to contain resolved types. 
- """ - if isinstance(cairo_type, TypeFelt): - return cairo_type - elif isinstance(cairo_type, TypePointer): - return dataclasses.replace(cairo_type, pointee=mark_type_resolved(cairo_type.pointee)) - elif isinstance(cairo_type, TypeStruct): - if cairo_type.is_fully_resolved: - return cairo_type - return dataclasses.replace( - cairo_type, - is_fully_resolved=True) - else: - raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') - - -def is_type_resolved(cairo_type: CairoType) -> bool: - """ - Returns true if the type is resolved (struct names are absolute). - """ - if isinstance(cairo_type, TypeFelt): - return True - elif isinstance(cairo_type, TypePointer): - return is_type_resolved(cairo_type.pointee) - elif isinstance(cairo_type, TypeStruct): - return cairo_type.is_fully_resolved - else: - raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') - - -class MarkResolved(ExpressionTransformer): - def visit_ExprCast(self, expr: ExprCast): - return dataclasses.replace( - expr, expr=self.visit(expr.expr), dest_type=mark_type_resolved(expr.dest_type)) - - -def mark_types_in_expr_resolved(expr: Expression): - """ - Same as mark_type_resolved() except that it operates on all types within an expression. + This includes checking types in operations, removing casts, and expanding dot and subscript + operators. For example: + - expr=[cast(fp, T*)] will be transformed into ([fp], T); + - If T is a struct type with member x of type S at offset 2, then expr=[cast(fp, T*)].x will + be transformed into ([fp + 2], S); + - If T is a struct of size 3, then expr=cast(fp, T*)[5] will be transformed into + ([fp + 5 * 3], T). + In the second and third examples, the definition of struct T is looked up, and must be present, + in the IdentifierManager 'identifiers'.
""" - return MarkResolved().visit(expr) + return TypeSystemVisitor(identifiers=identifiers).visit(expr) diff --git a/src/starkware/cairo/lang/compiler/type_system_visitor_test.py b/src/starkware/cairo/lang/compiler/type_system_visitor_test.py index ee11296e..ac2431c4 100644 --- a/src/starkware/cairo/lang/compiler/type_system_visitor_test.py +++ b/src/starkware/cairo/lang/compiler/type_system_visitor_test.py @@ -1,60 +1,293 @@ import re +from typing import Optional import pytest +from starkware.cairo.lang.compiler.ast.ast_objects_test_utils import remove_parentheses from starkware.cairo.lang.compiler.ast.cairo_types import ( - TypeFelt, TypePointer, TypeStruct, TypeTuple) + CairoType, TypeFelt, TypePointer, TypeStruct, TypeTuple) +from starkware.cairo.lang.compiler.ast_objects_test import remove_parentheses +from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.compiler.type_system import mark_types_in_expr_resolved from starkware.cairo.lang.compiler.type_system_visitor import CairoTypeError, simplify_type_system scope = ScopedName.from_string +def simplify_type_system_test( + orig_expr: str, simplified_expr: str, simplified_type: CairoType, + identifiers: Optional[IdentifierManager] = None): + parsed_expr = mark_types_in_expr_resolved(parse_expr(orig_expr)) + assert simplify_type_system(parsed_expr, identifiers=identifiers) == ( + parse_expr(simplified_expr), simplified_type) + + def test_type_visitor(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=False) + t = TypeStruct(scope=scope('T'), is_fully_resolved=True) t_star = TypePointer(pointee=t) t_star2 = TypePointer(pointee=t_star) - assert simplify_type_system(parse_expr('fp + 3 + [ap]')) == ( - parse_expr('fp + 3 + [ap]'), 
TypeFelt()) - assert simplify_type_system(parse_expr('cast(fp + 3 + [ap], T*)')) == ( - parse_expr('fp + 3 + [ap]'), t_star) + + simplify_type_system_test('fp + 3 + [ap]', 'fp + 3 + [ap]', TypeFelt()) + simplify_type_system_test('cast(fp + 3 + [ap], T*)', 'fp + 3 + [ap]', t_star) # Two casts. - assert simplify_type_system(parse_expr('cast(cast(fp, T*), felt)')) == ( - parse_expr('fp'), TypeFelt()) + simplify_type_system_test('cast(cast(fp, T*), felt)', 'fp', TypeFelt()) # Cast from T to T. - assert simplify_type_system(parse_expr('cast([cast(fp, T*)], T)')) == ( - parse_expr('[fp]'), t) + simplify_type_system_test('cast([cast(fp, T*)], T)', '[fp]', t) # Dereference. - assert simplify_type_system(parse_expr('[cast(fp, T**)]')) == ( - parse_expr('[fp]'), t_star) - assert simplify_type_system(parse_expr('[[cast(fp, T**)]]')) == ( - parse_expr('[[fp]]'), t) + simplify_type_system_test('[cast(fp, T**)]', '[fp]', t_star) + simplify_type_system_test('[[cast(fp, T**)]]', '[[fp]]', t) # Address of. - assert simplify_type_system(parse_expr('&([[cast(fp, T**)]])')) == ( - parse_expr('[fp]'), t_star) - assert simplify_type_system(parse_expr('&&[[cast(fp, T**)]]')) == ( - parse_expr('fp'), t_star2) + simplify_type_system_test('&([[cast(fp, T**)]])', '[fp]', t_star) + simplify_type_system_test('&&[[cast(fp, T**)]]', 'fp', t_star2) def test_type_tuples(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=False) + t = TypeStruct(scope=scope('T'), is_fully_resolved=True) t_star = TypePointer(pointee=t) # Simple tuple. - assert simplify_type_system(parse_expr('(fp, [cast(fp, T*)], cast(fp,T*))')) == ( - parse_expr('(fp, [fp], fp)'), TypeTuple(members=[TypeFelt(), t, t_star],) - ) + simplify_type_system_test( + '(fp, [cast(fp, T*)], cast(fp,T*))', + '(fp, [fp], fp)', TypeTuple(members=[TypeFelt(), t, t_star],)) # Nested. 
- assert simplify_type_system(parse_expr('(fp, (), ([cast(fp, T*)],))')) == ( - parse_expr('(fp, (), ([fp],))'), TypeTuple( - members=[ - TypeFelt(), - TypeTuple(members=[]), - TypeTuple(members=[t])], - ) - ) + simplify_type_system_test('(fp, (), ([cast(fp, T*)],))', '(fp, (), ([fp],))', TypeTuple( + members=[ + TypeFelt(), + TypeTuple(members=[]), + TypeTuple(members=[t])], + )) + + +def test_type_tuples_failures(): + identifier_dict = { + scope('T'): StructDefinition( + full_name=scope('T'), + members={ + 'x': MemberDefinition(offset=0, cairo_type=TypeFelt()), + 'y': MemberDefinition(offset=1, cairo_type=TypeFelt()), + }, + size=2, + ), + } + identifiers = IdentifierManager.from_dict(identifier_dict) + + verify_exception('1 + cast((1, 2), T).x', """ +file:?:?: Accessing struct members for r-value structs is not supported yet. +1 + cast((1, 2), T).x + ^***************^ +""", identifiers=identifiers) + + +def test_type_subscript_op(): + felt_star_star = TypePointer(pointee=TypePointer(pointee=TypeFelt())) + t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + t_star = TypePointer(pointee=t) + + identifier_dict = {scope('T'): StructDefinition(full_name=scope('T'), members={}, size=7)} + identifiers = IdentifierManager.from_dict(identifier_dict) + + simplify_type_system_test('cast(fp, felt*)[3]', '[fp + 3 * 1]', TypeFelt()) + simplify_type_system_test('cast(fp, felt***)[0]', '[fp + 0 * 1]', felt_star_star) + simplify_type_system_test('[cast(fp, T****)][ap][ap]', '[[[fp] + ap * 1] + ap * 1]', t_star) + simplify_type_system_test( + 'cast(fp, T**)[1][2]', '[[fp + 1 * 1] + 2 * 7]', t, identifiers=identifiers) + + # Test that 'cast(fp, T*)[2 * ap + 3]' simplifies into '[fp + (2 * ap + 3) * 7]', but without + # the parentheses. + assert simplify_type_system( + mark_types_in_expr_resolved(parse_expr('cast(fp, T*)[2 * ap + 3]')), + identifiers=identifiers) == ( + remove_parentheses(parse_expr('[fp + (2 * ap + 3) * 7]')), t) + + # Test subscript operator for tuples. 
+ simplify_type_system_test('(cast(fp, felt**), fp, cast(fp, T*))[2]', 'fp', t_star) + simplify_type_system_test('(cast(fp, felt**), fp, cast(fp, T*))[0]', 'fp', felt_star_star) + simplify_type_system_test('(cast(fp, felt**), ap, cast(fp, T*))[3*4 - 11]', 'ap', TypeFelt()) + simplify_type_system_test('[cast(ap, (felt, felt)*)][0]', '[ap + 0]', TypeFelt()) + simplify_type_system_test( + '[cast(ap, (T*, T, felt, T*, felt*)*)][3]', '[ap + 9]', t_star, identifiers=identifiers) + + # Test failures. + + verify_exception('(fp, fp, fp)[cast(ap, felt*)]', """ +file:?:?: Cannot apply subscript-operator with offset of non-felt type 'felt*'. +(fp, fp, fp)[cast(ap, felt*)] + ^*************^ +""") + + verify_exception('(fp, fp, fp)[[ap]]', """ +file:?:?: Subscript-operator for tuples supports only constant offsets, found 'ExprDeref'. +(fp, fp, fp)[[ap]] + ^**^ +""") + + # The simplifier in TypeSystemVisitor cannot access PRIME, so PyConsts are unsimplified. + verify_exception('(fp, fp, fp)[%[1%]]', """ +file:?:?: Subscript-operator for tuples supports only constant offsets, found 'ExprPyConst'. +(fp, fp, fp)[%[1%]] + ^***^ +""") + + verify_exception('(fp, fp, fp)[3]', """ +file:?:?: Tuple index 3 is out of range [0, 3). +(fp, fp, fp)[3] +^*************^ +""") + + verify_exception('[cast(fp, (T*, T, felt)*)][-1]', """ +file:?:?: Tuple index -1 is out of range [0, 3). +[cast(fp, (T*, T, felt)*)][-1] +^****************************^ +""") + + verify_exception('cast(fp, felt)[0]', """ +file:?:?: Cannot apply subscript-operator to non-pointer, non-tuple type 'felt'. +cast(fp, felt)[0] +^***************^ +""") + + verify_exception('[cast(fp, T*)][0]', """ +file:?:?: Cannot apply subscript-operator to non-pointer, non-tuple type 'T'. +[cast(fp, T*)][0] +^***************^ +""") + + verify_exception('cast(fp, felt*)[[cast(ap, T*)]]', """ +file:?:?: Cannot apply subscript-operator with offset of non-felt type 'T'. 
+cast(fp, felt*)[[cast(ap, T*)]] + ^************^ +""") + + verify_exception('cast(fp, Z*)[0]', """ +file:?:?: Unknown identifier 'Z'. +cast(fp, Z*)[0] +^*************^ +""", identifiers=identifiers) + + verify_exception('cast(fp, T*)[0]', """ +file:?:?: Unknown identifier 'T'. +cast(fp, T*)[0] +^*************^ +""", identifiers=None) + + +def test_type_dot_op(): + """ + Tests type_system_visitor for ExprDot-s, in the following struct architecture: + + struct S: + member x : felt + member y : felt + end + + struct T: + member t : felt + member s : S + member sp : S* + end + + struct R: + member r : R* + end + """ + t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + s = TypeStruct(scope=scope('S'), is_fully_resolved=True) + s_star = TypePointer(pointee=s) + r = TypeStruct(scope=scope('R'), is_fully_resolved=True) + r_star = TypePointer(pointee=r) + + identifier_dict = { + scope('T'): StructDefinition( + full_name=scope('T'), + members={ + 't': MemberDefinition(offset=0, cairo_type=TypeFelt()), + 's': MemberDefinition(offset=1, cairo_type=s), + 'sp': MemberDefinition(offset=3, cairo_type=s_star), + }, + size=4, + ), + scope('S'): StructDefinition( + full_name=scope('S'), + members={ + 'x': MemberDefinition(offset=0, cairo_type=TypeFelt()), + 'y': MemberDefinition(offset=1, cairo_type=TypeFelt()), + }, + size=2, + ), + scope('R'): StructDefinition( + full_name=scope('R'), + members={ + 'r': MemberDefinition(offset=0, cairo_type=r_star), + }, + size=1, + ), + } + + identifiers = IdentifierManager.from_dict(identifier_dict) + + for (orig_expr, simplified_expr, simplified_type) in [ + ('[cast(fp, T*)].t', '[fp]', TypeFelt()), + ('[cast(fp, T*)].s', '[fp + 1]', s), + ('[cast(fp, T*)].sp', '[fp + 3]', s_star), + ('[cast(fp, T*)].s.x', '[fp + 1]', TypeFelt()), + ('[cast(fp, T*)].s.y', '[fp + 1 + 1]', TypeFelt()), + ('[[cast(fp, T*)].sp].x', '[[fp + 3]]', TypeFelt()), + ('[cast(fp, R*)]', '[fp]', r), + ('[cast(fp, R*)].r', '[fp]', r_star), + ('[[[cast(fp, 
R*)].r].r].r', '[[[fp]]]', r_star), + # Test . as -> + ('cast(fp, T*).t', '[fp]', TypeFelt()), + ('cast(fp, T*).sp.y', '[[fp + 3] + 1]', TypeFelt()), + ('cast(fp, R*).r.r.r', '[[[fp]]]', r_star), + # More tests. + ('(cast(fp, T*).s)', '[fp + 1]', s), + ('(cast(fp, T*).s).x', '[fp + 1]', TypeFelt()), + ('(&(cast(fp, T*).s)).x', '[fp + 1]', TypeFelt()) + ]: + simplify_type_system_test( + orig_expr, simplified_expr, simplified_type, identifiers=identifiers) + + # Test failures. + + verify_exception('cast(fp, felt).x', """ +file:?:?: Cannot apply dot-operator to non-struct type 'felt'. +cast(fp, felt).x +^**************^ +""", identifiers=identifiers) + + verify_exception('cast(fp, felt*).x', """ +file:?:?: Cannot apply dot-operator to pointer-to-non-struct type 'felt*'. +cast(fp, felt*).x +^***************^ +""", identifiers=identifiers) + + verify_exception('cast(fp, T*).x', """ +file:?:?: Member 'x' does not appear in definition of struct 'T'. +cast(fp, T*).x +^************^ +""", identifiers=identifiers) + + verify_exception('cast(fp, Z*).x', """ +file:?:?: Unknown identifier 'Z'. +cast(fp, Z*).x +^************^ +""", identifiers=identifiers) + + verify_exception('cast(fp, T*).x', """ +file:?:?: Identifiers must be initialized for type-simplification of dot-operator expressions. +cast(fp, T*).x +^************^ +""", identifiers=None) + + verify_exception('cast(fp, Z*).x', """ +file:?:?: Type is expected to be fully resolved at this point. +cast(fp, Z*).x +^************^ +""", identifiers=identifiers, resolve_types=False) def test_type_visitor_failures(): @@ -69,9 +302,9 @@ def test_type_visitor_failures(): ^**************^ """) verify_exception('[cast(fp, T)]', """ -file:?:?: Cannot cast to 'T' since the expression has no address. +file:?:?: Cannot cast 'felt' to 'T'. [cast(fp, T)] - ^^ + ^*********^ """) verify_exception('&(cast(fp, T*) + 3)', """ file:?:?: Expression has no address. 
@@ -81,14 +314,12 @@ def test_type_visitor_failures(): def test_type_visitor_pointer_arithmetic(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=False) + t = TypeStruct(scope=scope('T'), is_fully_resolved=True) t_star = TypePointer(pointee=t) - assert simplify_type_system(parse_expr('cast(fp, T*) + 3')) == ( - parse_expr('fp + 3'), t_star) - assert simplify_type_system(parse_expr('cast(fp, T*) - 3')) == ( - parse_expr('fp - 3'), t_star) - assert simplify_type_system(parse_expr('cast(fp, T*) - cast(3, T*)')) == ( - parse_expr('fp - 3'), TypeFelt()) + + simplify_type_system_test('cast(fp, T*) + 3', 'fp + 3', t_star) + simplify_type_system_test('cast(fp, T*) - 3', 'fp - 3', t_star) + simplify_type_system_test('cast(fp, T*) - cast(3, T*)', 'fp - 3', TypeFelt()) def test_type_visitor_pointer_arithmetic_failures(): @@ -109,11 +340,18 @@ def test_type_visitor_pointer_arithmetic_failures(): """) -def verify_exception(expr_str: str, error: str): +def verify_exception( + expr_str: str, + error: str, + identifiers: Optional[IdentifierManager] = None, + resolve_types=True): """ Verifies that calling simplify_type_system() on the code results in the given error. """ with pytest.raises(CairoTypeError) as e: - simplify_type_system(parse_expr(expr_str)) + parsed_expr = parse_expr(expr_str) + if resolve_types: + parsed_expr = mark_types_in_expr_resolved(parsed_expr) + simplify_type_system(parsed_expr, identifiers) # Remove line and column information from the error using a regular expression. 
assert re.sub(':[0-9]+:[0-9]+: ', 'file:?:?: ', str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index 0090a409..fa4040bf 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.1.0", + "version": "0.2.0", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/instances.py b/src/starkware/cairo/lang/instances.py index 7094e155..e6e19f21 100644 --- a/src/starkware/cairo/lang/instances.py +++ b/src/starkware/cairo/lang/instances.py @@ -2,8 +2,6 @@ from dataclasses import field from typing import Any, Dict -from starkware.cairo.lang.builtins.checkpoints.instance_def import ( - CELLS_PER_SAMPLE, CheckpointsInstanceDef) from starkware.cairo.lang.builtins.hash.instance_def import CELLS_PER_HASH, PedersenInstanceDef from starkware.cairo.lang.builtins.range_check.instance_def import ( CELLS_PER_RANGE_CHECK, RangeCheckInstanceDef) @@ -11,6 +9,12 @@ CELLS_PER_SIGNATURE, EcdsaInstanceDef) +@dataclasses.dataclass +class CpuInstanceDef: + # Verifies that each 'call' instruction returns, even if the called function is malicious. + safe_call: bool = True + + @dataclasses.dataclass class CairoLayout: layout_name: str = '' @@ -21,13 +25,13 @@ class CairoLayout: # The ratio between the number of public memory cells and the total number of memory cells. 
public_memory_fraction: int = 4 memory_units_per_step: int = 8 + cpu_instance_def: CpuInstanceDef = field(default=CpuInstanceDef()) CELLS_PER_BUILTIN = dict( pedersen=CELLS_PER_HASH, range_check=CELLS_PER_RANGE_CHECK, ecdsa=CELLS_PER_SIGNATURE, - checkpoints=CELLS_PER_SAMPLE, ) plain_instance = CairoLayout( @@ -57,9 +61,6 @@ class CairoLayout: height=256, n_hash_bits=251, ), - checkpoints=CheckpointsInstanceDef( - sample_ratio=16, - ), ) ) @@ -86,9 +87,6 @@ class CairoLayout: height=256, n_hash_bits=251, ), - checkpoints=CheckpointsInstanceDef( - sample_ratio=16, - ), ) ) diff --git a/src/starkware/cairo/lang/lang.cmake b/src/starkware/cairo/lang/lang.cmake new file mode 100644 index 00000000..2cbcd4fc --- /dev/null +++ b/src/starkware/cairo/lang/lang.cmake @@ -0,0 +1,55 @@ +python_lib(cairo_version_lib + PREFIX starkware/cairo/lang + + FILES + VERSION + version.py +) + +if (NOT DEFINED CAIRO_PYTHON_INTERPRETER) + set(CAIRO_PYTHON_INTERPRETER python3.7) +endif() + +python_venv(cairo_lang_venv + PYTHON ${CAIRO_PYTHON_INTERPRETER} + LIBS + cairo_bootloader_generate_fact_lib + cairo_common_lib + cairo_compile_lib + cairo_hash_program_lib + cairo_run_lib + cairo_script_lib + ${CAIRO_LANG_VENV_ADDITIONAL_LIBS} +) + +python_venv(cairo_lang_package_venv + PYTHON python3.7 + LIBS + cairo_bootloader_generate_fact_lib + cairo_common_lib + cairo_compile_lib + cairo_hash_program_lib + cairo_run_lib + cairo_script_lib + sharp_client_config_lib + sharp_client_lib + starknet_script_lib +) + +python_lib(cairo_instances_lib + PREFIX starkware/cairo/lang + + FILES + instances.py + ${CAIRO_INSTANCES_LIB_ADDITIONAL_FILES} + + LIBS + cairo_run_builtins_lib +) + +python_lib(cairo_constants_lib + PREFIX starkware/cairo/lang + + FILES + cairo_constants.py +) diff --git a/src/starkware/cairo/lang/package_test/run_test.sh b/src/starkware/cairo/lang/package_test/run_test.sh index 02e1187f..8f098101 100755 --- a/src/starkware/cairo/lang/package_test/run_test.sh +++ 
b/src/starkware/cairo/lang/package_test/run_test.sh @@ -28,3 +28,10 @@ res=$(cairo-run --program=main_compiled.json --layout=small --print_output) # Verify the result. # The number below is pedersen(1, 2) (which is the expected output of main.cairo). [[ "$res" == *"-1025514936890165471153863463586721648332140962090141185746964417035414175707"* ]] + + +# Test StarkNet compiler. +starknet-compile ${root_dir}/src/starkware/starknet/apps/amm_sample/amm_sample.cairo > /dev/null + +# Test StarkNet CLI. +starknet --help > /dev/null diff --git a/src/starkware/cairo/lang/setup.py b/src/starkware/cairo/lang/setup.py index 16f1a2b2..7e1fb6e2 100644 --- a/src/starkware/cairo/lang/setup.py +++ b/src/starkware/cairo/lang/setup.py @@ -21,18 +21,22 @@ setup_requires=['wheel'], url='https://cairo-lang.org/', package_data={ - 'starkware.cairo.lang': ['VERSION'], + 'starkware.cairo.common': ['*.cairo'], 'starkware.cairo.lang.compiler': ['cairo.ebnf'], 'starkware.cairo.lang.tracer': ['*.html', '*.css', '*.js', '*.png'], - 'starkware.cairo.common': ['*.cairo'], - 'starkware.crypto.signature': ['pedersen_params.json'], + 'starkware.cairo.lang': ['VERSION'], 'starkware.cairo.sharp': ['config.json'], + 'starkware.crypto.signature': ['pedersen_params.json'], + 'starkware.starknet': ['core/storage/*.cairo'], + 'starkware.starknet.security': ['whitelists/latest.json'], }, scripts=[ - 'starkware/cairo/lang/scripts/cairo-format', 'starkware/cairo/lang/scripts/cairo-compile', - 'starkware/cairo/lang/scripts/cairo-run', + 'starkware/cairo/lang/scripts/cairo-format', 'starkware/cairo/lang/scripts/cairo-hash-program', + 'starkware/cairo/lang/scripts/cairo-run', 'starkware/cairo/lang/scripts/cairo-sharp', + 'starkware/starknet/scripts/starknet-compile', + 'starkware/starknet/scripts/starknet', ] ) diff --git a/src/starkware/cairo/lang/tracer/profile.py b/src/starkware/cairo/lang/tracer/profile.py index 0c473195..8b068640 100644 --- a/src/starkware/cairo/lang/tracer/profile.py +++ 
b/src/starkware/cairo/lang/tracer/profile.py @@ -107,6 +107,7 @@ def function_id(self, name: str, inst_location: InstructionLocation) -> int: func = self._profile.function.add() func.id = func_id func.system_name = func.name = self.string_id(name) + assert inst_location.inst.input_file.filename is not None func.filename = self.string_id(inst_location.inst.input_file.filename) func.start_line = inst_location.inst.start_line return func_id @@ -126,6 +127,7 @@ def location_id(self, pc, inst_location: InstructionLocation) -> int: line = location.line.add() line.function_id = self._func_name_to_id[str(inst_location.accessible_scopes[-1])] line.line = sublocation.start_line + assert sublocation.input_file.filename is not None location.mapping_id = self.update_mapping_pc_range(sublocation.input_file.filename, pc) return self._pc_to_location_id[pc] @@ -166,6 +168,7 @@ def profile_from_tracer_data(tracer_data: TracerData): # Functions. identifiers_dict = tracer_data.program.identifiers.as_dict() + assert tracer_data.program.debug_info is not None for name, ident in identifiers_dict.items(): if not isinstance(ident, LabelDefinition): continue diff --git a/src/starkware/cairo/lang/tracer/tracer_data.py b/src/starkware/cairo/lang/tracer/tracer_data.py index 14cb40f1..68aff2e6 100644 --- a/src/starkware/cairo/lang/tracer/tracer_data.py +++ b/src/starkware/cairo/lang/tracer/tracer_data.py @@ -9,12 +9,12 @@ from starkware.cairo.lang.compiler.encode import decode_instruction from starkware.cairo.lang.compiler.expression_evaluator import ExpressionEvaluator from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition, ReferenceDefinition -from starkware.cairo.lang.compiler.identifier_utils import resolve_search_result from starkware.cairo.lang.compiler.offset_reference import OffsetReferenceDefinition from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.references 
import ( FlowTrackingError, SubstituteRegisterTransformer) +from starkware.cairo.lang.compiler.resolve_search_result import resolve_search_result from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.substitute_identifiers import substitute_identifiers from starkware.cairo.lang.compiler.type_system_visitor import simplify_type_system @@ -115,6 +115,7 @@ def __init__( for pc_offset, instruction_location in self.debug_info.instruction_locations.items(): loc = instruction_location.inst filename = loc.input_file.filename + assert filename is not None # If filename was not loaded yet, create a new InputCodeFile instance. if filename not in self.input_files: self.input_files[filename] = InputCodeFile(loc.input_file.get_content()) @@ -238,7 +239,7 @@ def field_element_repr(val: int, prime: int) -> str: class WatchEvaluator(ExpressionEvaluator): - def __init__(self, tracer_data: TracerData, entry: TraceEntry): + def __init__(self, tracer_data: TracerData, entry: TraceEntry[int]): super().__init__( prime=tracer_data.program.prime, ap=entry.ap, fp=entry.fp, memory=tracer_data.memory) @@ -260,7 +261,8 @@ def eval(self, expr): expr, expr_type = simplify_type_system( substitute_identifiers( expr=parse_expr(expr), - get_identifier_callback=self.get_variable)) + get_identifier_callback=self.get_variable), + identifiers=self.tracer_data.program.identifiers) if isinstance(expr_type, TypeStruct): raise NotImplementedError('Structs are not supported.') res = self.visit(expr) @@ -293,6 +295,7 @@ def get_variable(self, var: ExprIdentifier): def eval_reference(self, identifier_definition, var_name: str): pc_offset = self.tracer_data.get_pc_offset(self.pc) + assert self.tracer_data.program.debug_info is not None current_flow_tracking_data = \ self.tracer_data.program.debug_info.instruction_locations[pc_offset].flow_tracking_data try: diff --git a/src/starkware/cairo/lang/tracer/tracer_data_test.py 
b/src/starkware/cairo/lang/tracer/tracer_data_test.py index bd6c4d51..8d301cf2 100644 --- a/src/starkware/cairo/lang/tracer/tracer_data_test.py +++ b/src/starkware/cairo/lang/tracer/tracer_data_test.py @@ -79,10 +79,9 @@ def test_tracer_data(): assert tracer_data.get_current_identifier_values(trace[1]) == {'output_ptr': '21', 'x': '2000'} assert tracer_data.get_current_identifier_values(trace[2]) == { 'output_ptr': '21', 'x': '5000', 'y': '3000'} - # '__temp1' identifier is already available in this step, but should not be returned as its - # value is still unknown. - assert tracer_data.get_current_identifier_values(trace[3]) == \ - tracer_data.get_current_identifier_values(trace[4]) == { - 'output_ptr': '21', 'x': '5000', 'y': '3000', '__temp0': '1234'} + assert tracer_data.get_current_identifier_values(trace[3]) == { + 'output_ptr': '21', 'x': '5000', 'y': '3000'} + assert tracer_data.get_current_identifier_values(trace[4]) == { + 'output_ptr': '21', 'x': '5000', 'y': '3000', '__temp0': '1234'} assert tracer_data.get_current_identifier_values(trace[5]) == { - 'output_ptr': '21', 'x': '5000', 'y': '3000', '__temp0': '1234', '__temp1': '4321'} + 'output_ptr': '21', 'x': '5000', 'y': '3000', '__temp0': '1234'} diff --git a/src/starkware/cairo/lang/vm/CMakeLists.txt b/src/starkware/cairo/lang/vm/CMakeLists.txt index 978ac51e..b0662452 100644 --- a/src/starkware/cairo/lang/vm/CMakeLists.txt +++ b/src/starkware/cairo/lang/vm/CMakeLists.txt @@ -1,6 +1,6 @@ # WARNING: This library is used by the Cairo playground hint server. Please don't add files # unless they should be there. 
-python_lib(cairo_relocatable +python_lib(cairo_relocatable_lib PREFIX starkware/cairo/lang/vm FILES @@ -41,7 +41,7 @@ python_lib(cairo_vm_lib LIBS cairo_compile_lib - cairo_relocatable + cairo_relocatable_lib cairo_vm_crypto_lib starkware_python_utils_lib ) @@ -93,6 +93,7 @@ full_python_test(cairo_vm_test cairo_run_lib cairo_vm_lib starkware_python_utils_lib + starkware_python_test_utils_lib pip_marshmallow pip_marshmallow_dataclass pip_pytest diff --git a/src/starkware/cairo/lang/vm/cairo_pie.py b/src/starkware/cairo/lang/vm/cairo_pie.py index 3971cb05..99de1cf5 100644 --- a/src/starkware/cairo/lang/vm/cairo_pie.py +++ b/src/starkware/cairo/lang/vm/cairo_pie.py @@ -2,6 +2,7 @@ A CairoPie represents a position independent execution of a Cairo program. """ +import copy import dataclasses import io import json @@ -17,7 +18,7 @@ from starkware.cairo.lang.compiler.program import StrippedProgram, is_valid_builtin_name from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import RelocatableValue -from starkware.python.utils import add_counters +from starkware.python.utils import add_counters, sub_counters @dataclasses.dataclass @@ -130,19 +131,33 @@ def run_validity_checks(self): for name, size in self.builtin_instance_counter.items()), \ 'Invalid builtin_instance_counter.' 
- def __add__(self, other): - total_n_steps = self.n_steps + other.n_steps + def __add__(self, other: 'ExecutionResources') -> 'ExecutionResources': total_builtin_instance_counter = add_counters( self.builtin_instance_counter, other.builtin_instance_counter) - total_n_memory_holes = self.n_memory_holes + other.n_memory_holes + return ExecutionResources( - n_steps=total_n_steps, builtin_instance_counter=total_builtin_instance_counter, - n_memory_holes=total_n_memory_holes) + n_steps=self.n_steps + other.n_steps, + builtin_instance_counter=total_builtin_instance_counter, + n_memory_holes=self.n_memory_holes + other.n_memory_holes) + + def __sub__(self, other: 'ExecutionResources') -> 'ExecutionResources': + diff_builtin_instance_counter = sub_counters( + self.builtin_instance_counter, other.builtin_instance_counter) + diff_execution_resources = ExecutionResources( + n_steps=self.n_steps - other.n_steps, + builtin_instance_counter=diff_builtin_instance_counter, + n_memory_holes=self.n_memory_holes - other.n_memory_holes) + diff_execution_resources.run_validity_checks() + + return diff_execution_resources @classmethod def empty(cls): return cls(n_steps=0, builtin_instance_counter={}, n_memory_holes=0) + def copy(self) -> 'ExecutionResources': + return copy.deepcopy(self) + @dataclasses.dataclass class CairoPie: @@ -294,6 +309,10 @@ def verify_zip_format(cls, zf: zipfile.ZipFile): assert inner_files[cls.EXECUTION_RESOURCES_FILENAME].file_size < 10000, \ f'Invalid file size for {cls.EXECUTION_RESOURCES_FILENAME}.' 
+ def get_segment(self, segment_info: SegmentInfo): + return self.memory.get_range( + RelocatableValue(segment_index=segment_info.index, offset=0), size=segment_info.size) + def verify_zip_file_prefix(fileobj): """ diff --git a/src/starkware/cairo/lang/vm/cairo_runner.py b/src/starkware/cairo/lang/vm/cairo_runner.py index 790b104d..f494ca9a 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner.py +++ b/src/starkware/cairo/lang/vm/cairo_runner.py @@ -1,15 +1,15 @@ import functools -from typing import Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union -from starkware.cairo.lang.builtins.checkpoints.checkpoints_builtin_runner import ( - CheckpointsBuiltinRunner) from starkware.cairo.lang.builtins.hash.hash_builtin_runner import HashBuiltinRunner from starkware.cairo.lang.builtins.range_check.range_check_builtin_runner import ( RangeCheckBuiltinRunner) from starkware.cairo.lang.builtins.signature.signature_builtin_runner import SignatureBuiltinRunner -from starkware.cairo.lang.compiler.cairo_compile import compile_cairo, compile_cairo_files +from starkware.cairo.lang.compiler.cairo_compile import ( + compile_cairo, compile_cairo_files, get_module_reader) from starkware.cairo.lang.compiler.debug_info import DebugInfo from starkware.cairo.lang.compiler.expression_simplifier import to_field_element +from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import default_pass_manager from starkware.cairo.lang.compiler.preprocessor.preprocessor import Preprocessor from starkware.cairo.lang.compiler.program import Program, ProgramBase from starkware.cairo.lang.instances import LAYOUTS @@ -66,7 +66,7 @@ def __init__( assert len(non_existing_builtins) == 0, \ f'Builtins {non_existing_builtins} are not present in layout "{self.layout}"' - if self.layout in ['small', 'dex', 'test']: + if self.layout != 'plain': builtin_factories = { 'output': lambda name, included: 
OutputBuiltinRunner(included=included), 'pedersen': functools.partial( @@ -79,7 +79,6 @@ def __init__( 'ecdsa': functools.partial( SignatureBuiltinRunner, ratio=instance.builtins['ecdsa'].ratio, process_signature=process_ecdsa, verify_signature=verify_ecdsa_sig), - 'checkpoints': functools.partial(CheckpointsBuiltinRunner, sample_ratio=16), } for name, factory in builtin_factories.items(): @@ -105,8 +104,14 @@ def from_file( remove_hints: bool = False, remove_builtins: bool = False, memory: MemoryDict = None, preprocessor_cls: Type[Preprocessor] = Preprocessor, proof_mode: Optional[bool] = None) -> 'CairoRunner': + module_reader = get_module_reader(cairo_path=[]) program = compile_cairo_files( - files=[filename], prime=prime, debug_info=True, preprocessor_cls=preprocessor_cls) + files=[filename], + debug_info=True, + pass_manager=default_pass_manager( + prime=prime, + read_module=module_reader.read, + preprocessor_cls=preprocessor_cls)) if remove_hints: program.hints = {} if remove_builtins: @@ -134,22 +139,22 @@ def initialize_main_entrypoint(self): Returns the value of the program counter after returning from main. """ - self.execution_public_memory: List[Tuple[int, int]] = [] + self.execution_public_memory: List[int] = [] stack: List[MaybeRelocatable] = [] for builtin_runner in self.builtin_runners.values(): stack += builtin_runner.initial_stack() - self.execution_public_memory = [(i, 0) for i in range(len(stack))] - if self.proof_mode: - if len(stack) == 0: - # Make sure [fp - 1] is always initialized. - stack = [0] + # Add the dummy last fp and pc to the public memory, so that the verifier can enforce + # [fp - 2] = fp. + stack = [self.execution_base + 2, 0] + stack + self.execution_public_memory = list(range(len(stack))) assert isinstance(self.program, Program), \ '--proof_mode cannot be used with a StrippedProgram.' 
self.initialize_state(self.program.start, stack) + self.initial_fp = self.initial_ap = self.execution_base + 2 return self.program_base + self.program.get_label('__end__') else: return_fp = self.segments.add() @@ -163,7 +168,9 @@ def initialize_function_entrypoint( self, entrypoint: Union[str, int], args: Sequence[MaybeRelocatable], return_fp: MaybeRelocatable = 0): end = self.segments.add() - self.initialize_state(entrypoint, list(args) + [return_fp, end]) + stack = list(args) + [return_fp, end] + self.initialize_state(entrypoint, stack) + self.initial_fp = self.initial_ap = self.execution_base + len(stack) self.final_pc = end return end @@ -172,9 +179,11 @@ def initialize_state(self, entrypoint: Union[str, int], stack: Sequence[MaybeRel # Load program. self.load_data(self.program_base, self.program.data) # Load stack. - self.initial_fp = self.initial_ap = self.load_data(self.execution_base, stack) + self.load_data(self.execution_base, stack) - def initialize_vm(self, hint_locals, vm_class=VirtualMachine): + def initialize_vm( + self, hint_locals, static_locals: Optional[Dict[str, Any]] = None, + vm_class=VirtualMachine): context = RunContext( pc=self.initial_pc, ap=self.initial_ap, @@ -183,9 +192,12 @@ def initialize_vm(self, hint_locals, vm_class=VirtualMachine): prime=self.program.prime, ) + if static_locals is None: + static_locals = {} + self.vm = vm_class( self.program, context, hint_locals=hint_locals, - static_locals=dict(segments=self.segments), + static_locals=dict(segments=self.segments, **static_locals), builtin_runners=self.builtin_runners, program_base=self.program_base, ) @@ -245,6 +257,7 @@ def run_until_next_power_of_2(self): self.run_until_steps(next_power_of_2(self.vm.current_step)) def end_run(self): + self.vm_memory.relocate_memory() self.vm.end_run() def read_return_values(self): @@ -255,8 +268,8 @@ def read_return_values(self): for builtin_runner in list(self.builtin_runners.values())[::-1]: pointer = builtin_runner.final_stack(self, 
pointer) # Add return values to public memory. - self.execution_public_memory += [(i, 0) for i in range( - pointer - self.execution_base, self.vm.run_context.ap - self.execution_base)] + self.execution_public_memory += list(range( + pointer - self.execution_base, self.vm.run_context.ap - self.execution_base)) def check_used_cells(self): """ @@ -280,7 +293,8 @@ def finalize_segments(self): self.segments.finalize( self.execution_base.segment_index, size=get_segment_used_size(self.execution_base.segment_index, self.vm_memory), - public_memory=self.execution_public_memory) + public_memory=[ + (x + self.execution_base.offset, 0) for x in self.execution_public_memory]) for builtin_runner in self.builtin_runners.values(): builtin_runner.finalize_segments(self) @@ -444,13 +458,16 @@ def print_output(self, output_callback=to_field_element): print() def print_info(self, relocated: bool): + print(self.get_info(relocated=relocated)) + + def get_info(self, relocated: bool) -> str: pc, ap, fp = self.vm.run_context.pc, self.vm.run_context.ap, self.vm.run_context.fp if relocated: pc = self.relocate_value(pc) ap = self.relocate_value(ap) fp = self.relocate_value(fp) - print(f"""\ + info = f"""\ Number of steps: {len(self.vm.trace)} { '' if self.original_steps is None else f'(originally, {self.original_steps})'} Used memory cells: {len(self.vm_memory)} @@ -458,20 +475,28 @@ def print_info(self, relocated: bool): pc = {pc} ap = {ap} fp = {fp} - """) + """ if self.segment_offsets is not None: - print('Segment relocation table:') + info += 'Segment relocation table:\n' for segment_index in range(self.segments.n_segments): - print(f'{segment_index:<5} {self.segment_offsets[segment_index]}') + info += f'{segment_index:<5} {self.segment_offsets[segment_index]}\n' + + return info + + def get_builtin_usage(self) -> str: + if len(self.builtin_runners) == 0: + return '' + + builtin_usage_str = '\nBuiltin usage:\n' + for name, builtin_runner in self.builtin_runners.items(): + used, size = 
builtin_runner.get_used_cells_and_allocated_size(self) + percentage = f'{used / size * 100:.2f}%' if size > 0 else '100%' + builtin_usage_str += f'{name:<30s} {percentage:>7s} (used {used} cells)\n' + + return builtin_usage_str def print_builtin_usage(self): - if self.builtin_runners: - print() - print('Builtin usage:') - for name, builtin_runner in self.builtin_runners.items(): - used, size = builtin_runner.get_used_cells_and_allocated_size(self) - percentage = f'{used / size * 100:.2f}%' if size > 0 else '100%' - print(f'{name:<30s} {percentage:>7s} (used {used} cells)') + print(self.get_builtin_usage()) def get_builtin_segments_info(self): builtin_segments: Dict[str, SegmentInfo] = {} @@ -567,10 +592,17 @@ def get_runner_from_code( Cairo runner and returns the runner. """ program = compile_cairo(code=code, prime=prime, debug_info=True) + return get_main_runner(program=program, hint_locals={}, layout=layout) + + +def get_main_runner(program: Program, hint_locals: Dict[str, Any], layout: str): + """ + Runs a main-entrypoint program using Cairo runner and returns the runner. 
+ """ runner = CairoRunner(program, layout=layout) runner.initialize_segments() end = runner.initialize_main_entrypoint() - runner.initialize_vm(hint_locals={}) + runner.initialize_vm(hint_locals=hint_locals) runner.run_until_pc(end) runner.read_return_values() runner.finalize_segments_by_effective_size() diff --git a/src/starkware/cairo/lang/vm/cairo_runner_test.py b/src/starkware/cairo/lang/vm/cairo_runner_test.py index bd0fd0e0..3e6aba15 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner_test.py +++ b/src/starkware/cairo/lang/vm/cairo_runner_test.py @@ -6,7 +6,7 @@ from starkware.cairo.lang.compiler.cairo_compile import compile_cairo from starkware.cairo.lang.vm.builtin_runner import InsufficientAllocatedCells from starkware.cairo.lang.vm.cairo_runner import CairoRunner, get_runner_from_code -from starkware.cairo.lang.vm.vm import VmException +from starkware.cairo.lang.vm.vm import VmException, VmExceptionBase CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'test.cairo') PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 @@ -97,7 +97,7 @@ def test_missing_exit_scope(): end """ with pytest.raises( - AssertionError, + VmExceptionBase, match=re.escape('Every enter_scope() requires a corresponding exit_scope().')): runner = get_runner_from_code(code, layout='small', prime=PRIME) diff --git a/src/starkware/cairo/lang/vm/crypto.py b/src/starkware/cairo/lang/vm/crypto.py index de941c1e..b39dd98a 100644 --- a/src/starkware/cairo/lang/vm/crypto.py +++ b/src/starkware/cairo/lang/vm/crypto.py @@ -2,6 +2,7 @@ from starkware.crypto.signature.fast_pedersen_hash import async_pedersen_hash_func # noqa from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash # noqa +from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash_func # noqa from starkware.crypto.signature.signature import verify as verify_ecdsa # noqa diff --git a/src/starkware/cairo/lang/vm/memory_dict.py b/src/starkware/cairo/lang/vm/memory_dict.py index 3e0d8109..e2421709 100644 --- 
a/src/starkware/cairo/lang/vm/memory_dict.py +++ b/src/starkware/cairo/lang/vm/memory_dict.py @@ -1,5 +1,5 @@ from collections import UserDict -from typing import Callable, List +from typing import Callable, Dict, List, Type from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue @@ -31,7 +31,14 @@ class MemoryDict(UserDict): * setitem: Checks that memory value is not changed. """ - def _check_element(self, num: MaybeRelocatable, name: str): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # A dict of segment relocation rules mapping a segment index to a RelocatableValue. + # See add_relocation_rule for more details. + self.relocation_rules: Dict[int, RelocatableValue] = {} + + def _check_element(self, num: MaybeRelocatable, name: str, exc_type: Type[Exception]): """ Checks that num is a valid Cairo value: positive int or relocatable. Currently, does not check that value < prime. @@ -39,20 +46,75 @@ def _check_element(self, num: MaybeRelocatable, name: str): if isinstance(num, RelocatableValue): return if not isinstance(num, int): - raise ValueError(f'{name} must be an int, not {type(num).__name__}.') + raise exc_type(f'{name} must be an int, not {type(num).__name__}.') if num < 0: - raise ValueError(f'{name} must be nonnegative. Got {num}.') + raise exc_type(f'{name} must be nonnegative. Got {num}.') + + def add_relocation_rule( + self, src_ptr: RelocatableValue, dest_ptr: RelocatableValue): + """ + Adds a relocation rule that moves values from the 'src_ptr' segment to 'dest_ptr'. + + 'src_ptr' must point to the start of a temporary segment (negative index with offset 0). + Once a relocation rule is set the memory dict relocates value on the fly, + allowing the VM to execute the assertion 'src_ptr' = 'dest_ptr'. + + Note that relocation rules are not applied to addresses during execution + and consequently adding a relocation rule does not allow the VM to + read the value at memory['dest_ptr']. 
+ """ + assert src_ptr.segment_index < 0, f'src_ptr.segment_index must be < 0, src_ptr={src_ptr}.' + assert src_ptr.offset == 0, f'src_ptr.offset must be 0, src_ptr={src_ptr}.' + segment_index = src_ptr.segment_index + assert segment_index not in self.relocation_rules, \ + f'The segment with index {segment_index} already has a relocation rule.' + + self.relocation_rules[segment_index] = dest_ptr + + def relocate_value(self, value): + """ + Relocates a value according to the relocation rules. + + The original value is returned if the relocation rules do not apply to value. + """ + if not isinstance(value, RelocatableValue): + return value + + segment_idx = value.segment_index + if segment_idx >= 0: + return value + + relocation = self.relocation_rules.get(segment_idx) + if relocation is None: + return value + + return relocation + value.offset + + def relocate_memory(self): + """ + Relocates the memory according to the relocation rules and clears self.relocation_rules. + """ + if len(self.relocation_rules) == 0: + return + + self.data = { + self.relocate_value(addr): self.relocate_value(value) + for addr, value in self.items() + } + self.relocation_rules = {} def __getitem__(self, addr: MaybeRelocatable) -> MaybeRelocatable: - self._check_element(addr, 'Memory address') + self._check_element(addr, 'Memory address', KeyError) try: - return super().__getitem__(addr) + value = super().__getitem__(addr) except KeyError: raise UnknownMemoryError(addr) from None + return self.relocate_value(value) + def __setitem__(self, addr: MaybeRelocatable, value: MaybeRelocatable): - self._check_element(addr, 'Memory address') - self._check_element(value, 'Memory value') + self._check_element(addr, 'Memory address', KeyError) + self._check_element(value, 'Memory value', ValueError) # Additionally, check that address doesn't have a negative offset. 
if isinstance(addr, RelocatableValue) and addr.offset < 0: @@ -78,6 +140,9 @@ def set_without_checks(self, addr: MaybeRelocatable, value: MaybeRelocatable): self.data[addr] = value def serialize(self, field_bytes): + assert len(self.relocation_rules) == 0, \ + 'Cannot serialize a MemoryDict with active segment relocation rules.' + return b''.join( RelocatableValue.to_bytes(addr, ADDR_SIZE_IN_BYTES, 'little') + RelocatableValue.to_bytes(value, field_bytes, 'little') diff --git a/src/starkware/cairo/lang/vm/memory_dict_test.py b/src/starkware/cairo/lang/vm/memory_dict_test.py index e4fccaf8..8d534e04 100644 --- a/src/starkware/cairo/lang/vm/memory_dict_test.py +++ b/src/starkware/cairo/lang/vm/memory_dict_test.py @@ -6,68 +6,106 @@ def test_memory_dict_serialize(): - md = MemoryDict({1: 2, 3: 4, 5: 6}) + memory = MemoryDict({1: 2, 3: 4, 5: 6}) expected_serialized = bytes([ 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0]) - serialized = md.serialize(3) + serialized = memory.serialize(3) assert expected_serialized == serialized - assert MemoryDict.deserialize(serialized, 3) == md + assert MemoryDict.deserialize(serialized, 3) == memory def test_memory_dict_getitem(): - md = MemoryDict({11: 12}) + memory = MemoryDict({11: 12}) with pytest.raises(UnknownMemoryError): - md[12] + memory[12] def test_memory_dict_check_element(): - md = MemoryDict() - with pytest.raises(ValueError, match='must be an int'): - md['not a number'] = 12 - with pytest.raises(ValueError, match='must be nonnegative'): - md[-12] = 13 + memory = MemoryDict() + with pytest.raises(KeyError, match='must be an int'): + memory['not a number'] = 12 + with pytest.raises(KeyError, match='must be nonnegative'): + memory[-12] = 13 with pytest.raises(ValueError, match='The offset of a relocatable value must be nonnegative'): - md[RelocatableValue(segment_index=10, offset=-2)] = 13 + memory[RelocatableValue(segment_index=10, offset=-2)] = 13 # A value may have a 
negative offset. - md[13] = RelocatableValue(segment_index=10, offset=-2) + memory[13] = RelocatableValue(segment_index=10, offset=-2) def test_memory_dict_get(): - md = MemoryDict({14: 15}) - assert md.get(14, 'default') == 15 - assert md.get(1234, 'default') == 'default' - with pytest.raises(ValueError, match='must be nonnegative'): - md.get(-10, 'default') + memory = MemoryDict({14: 15}) + assert memory.get(14, 'default') == 15 + assert memory.get(1234, 'default') == 'default' + assert memory.get(-10, 'default') == 'default' # Attempting to read address with a negative offset is ok, it simply returns None. - assert md.get(RelocatableValue(segment_index=10, offset=-2)) is None + assert memory.get(RelocatableValue(segment_index=10, offset=-2)) is None def test_memory_dict_setdefault(): - md = MemoryDict({14: 15}) - md.setdefault(14, 0) - assert md[14] == 15 - md.setdefault(123, 456) - assert md[123] == 456 + memory = MemoryDict({14: 15}) + memory.setdefault(14, 0) + assert memory[14] == 15 + memory.setdefault(123, 456) + assert memory[123] == 456 with pytest.raises(ValueError, match='must be an int'): - md.setdefault(10, 'default') - with pytest.raises(ValueError, match='must be nonnegative'): - md.setdefault(-10, 123) + memory.setdefault(10, 'default') + with pytest.raises(KeyError, match='must be nonnegative'): + memory.setdefault(-10, 123) with pytest.raises(ValueError, match='The offset of a relocatable value must be nonnegative'): - md[RelocatableValue(segment_index=10, offset=-2)] = 13 + memory[RelocatableValue(segment_index=10, offset=-2)] = 13 def test_memory_dict_in(): - md = MemoryDict({1: 2, 3: 4}) - assert 1 in md - assert 2 not in md + memory = MemoryDict({1: 2, 3: 4}) + assert 1 in memory + assert 2 not in memory # Test that `in` doesn't add the value to the dict. 
- assert 2 not in md + assert 2 not in memory def test_memory_dict_multiple_values(): - md = MemoryDict({5: 10}) - md[5] = 10 - md[5] = 10 + memory = MemoryDict({5: 10}) + memory[5] = 10 + memory[5] = 10 with pytest.raises(InconsistentMemoryError): - md[5] = 11 + memory[5] = 11 + + +def test_segment_relocations(): + memory = MemoryDict() + + temp_segment = RelocatableValue(segment_index=-1, offset=0) + memory[5] = temp_segment + 2 + assert memory[5] == RelocatableValue(segment_index=-1, offset=2) + relocation_target = RelocatableValue(segment_index=4, offset=25) + memory.add_relocation_rule(src_ptr=temp_segment, dest_ptr=relocation_target) + assert memory[5] == relocation_target + 2 + + memory[temp_segment + 3] = 17 + memory.relocate_memory() + assert memory.data == { + 5: relocation_target + 2, + relocation_target + 3: 17, + } + + +def test_segment_relocation_failures(): + memory = MemoryDict() + + relocation_target = RelocatableValue(segment_index=4, offset=25) + with pytest.raises(AssertionError, match='src_ptr.segment_index must be < 0, src_ptr=1:2.'): + memory.add_relocation_rule(src_ptr=RelocatableValue( + segment_index=1, offset=2), dest_ptr=relocation_target) + + with pytest.raises(AssertionError, match='src_ptr.offset must be 0, src_ptr=-3:2.'): + memory.add_relocation_rule(src_ptr=RelocatableValue( + segment_index=-3, offset=2), dest_ptr=relocation_target) + + memory.add_relocation_rule(src_ptr=RelocatableValue( + segment_index=-3, offset=0), dest_ptr=relocation_target) + + with pytest.raises( + AssertionError, match='The segment with index -3 already has a relocation rule.'): + memory.add_relocation_rule(src_ptr=RelocatableValue( + segment_index=-3, offset=0), dest_ptr=relocation_target) diff --git a/src/starkware/cairo/lang/vm/memory_segments.py b/src/starkware/cairo/lang/vm/memory_segments.py index 59497a2f..fa2bce90 100644 --- a/src/starkware/cairo/lang/vm/memory_segments.py +++ b/src/starkware/cairo/lang/vm/memory_segments.py @@ -22,6 +22,8 @@ def 
__init__(self, memory: MemoryDict, prime: int): # A map from segment index to a list of pairs (offset, page_id) that constitute the # public memory. Note that the offset is absolute (not based on the page_id). self.public_memory_offsets: Dict[int, List[Tuple[int, int]]] = {} + # The number of temporary segments, see 'add_temp_segment' for more details. + self.n_temp_segments = 0 def add(self, size: Optional[int] = None) -> RelocatableValue: """ @@ -35,6 +37,19 @@ def add(self, size: Optional[int] = None) -> RelocatableValue: self.finalize(segment_index, size) return RelocatableValue(segment_index=segment_index, offset=0) + def add_temp_segment(self) -> RelocatableValue: + """ + Adds a new temporary segment and returns its starting location as a RelocatableValue. + + A temporary segment is a segment that is relocated before the cairo pie is produced. + """ + + self.n_temp_segments += 1 + # Temporary segments have negative segment indices that start from -1. + segment_index = -self.n_temp_segments + + return RelocatableValue(segment_index=segment_index, offset=0) + def finalize( self, segment_index: int, size: int, public_memory: Sequence[Tuple[int, int]] = []): """ @@ -50,13 +65,14 @@ def finalize_all_by_effective_size(self): """ Finalizes all segments that were not finalized yet, by computing their current used size. """ + segment_index_to_max_offset = get_segment_index_to_max_offset(memory=self.memory) for segment_index in range(self.n_segments): if segment_index in self.segment_sizes: # Segment was already finalized. continue assert segment_index not in self.public_memory_offsets - self.segment_sizes[segment_index] = get_segment_used_size(segment_index, self.memory) + self.segment_sizes[segment_index] = segment_index_to_max_offset[segment_index] + 1 self.public_memory_offsets[segment_index] = [] def relocate_segments(self) -> Dict[int, int]: @@ -143,10 +159,24 @@ def get_segment_used_size(segment_index: int, memory: MemoryDict) -> int: was accessed. 
""" max_offset = -1 - for addr in memory.keys(): + for addr in memory: assert isinstance(addr, RelocatableValue), \ f'Expected memory address to be relocatable value. Found: {addr}.' if addr.segment_index != segment_index: continue max_offset = max(max_offset, addr.offset) return max_offset + 1 + + +def get_segment_index_to_max_offset(memory: MemoryDict) -> Dict[int, int]: + """ + Returns a mapping between the segment indices and the maximal offset that + was accessed in each segment. + """ + segment_index_to_max_offset: Dict[int, int] = defaultdict(lambda: -1) + for addr in memory: + assert isinstance(addr, RelocatableValue), \ + f'Expected memory address to be relocatable value. Found: {addr}.' + previous_max_offset = segment_index_to_max_offset[addr.segment_index] + segment_index_to_max_offset[addr.segment_index] = max(previous_max_offset, addr.offset) + return segment_index_to_max_offset diff --git a/src/starkware/cairo/lang/vm/output_builtin_runner_test.py b/src/starkware/cairo/lang/vm/output_builtin_runner_test.py index fcbe0693..1022ff07 100644 --- a/src/starkware/cairo/lang/vm/output_builtin_runner_test.py +++ b/src/starkware/cairo/lang/vm/output_builtin_runner_test.py @@ -59,7 +59,7 @@ def test_pages(runner_and_output_runner): segment_offsets = {0: 0, 1: 10, 2: 100} assert runner.segments.get_public_memory_addresses(segment_offsets=segment_offsets) == ( [(i, 0) for i in range(len(runner.program.data))] + # Program segment. - [(10, 0)] + # Execution segment. + [(10, 0), (11, 0), (12, 0)] + # Execution segment. [(100 + offset, page_id) for offset, page_id in offset_page_pairs]) # Output segment. 
diff --git a/src/starkware/cairo/lang/vm/vm.py b/src/starkware/cairo/lang/vm/vm.py index 7d95e3f1..3524a005 100644 --- a/src/starkware/cairo/lang/vm/vm.py +++ b/src/starkware/cairo/lang/vm/vm.py @@ -38,7 +38,13 @@ class Operands: op1: MaybeRelocatable -class VmException(LocationError): +class VmExceptionBase(Exception): + """ + Base class for exceptions thrown by the Cairo VM. + """ + + +class VmException(LocationError, VmExceptionBase): def __init__( self, pc, inst_location: Optional[InstructionLocation], inner_exc, traceback: Optional[str] = None, notes: Optional[List[str]] = None, hint: bool = False): @@ -56,7 +62,7 @@ def __init__( self.notes += notes -class InconsistentAutoDeductionError(Exception): +class InconsistentAutoDeductionError(VmExceptionBase): def __init__(self, addr, current_value, new_value): self.addr = addr self.current_value = current_value @@ -65,7 +71,7 @@ def __init__(self, addr, current_value, new_value): f'Inconsistent auto deduction rule at address {addr}. {current_value} != {new_value}.') -class PureValueError(Exception): +class PureValueError(VmExceptionBase): def __init__(self, oper, *values): self.oper = oper self.values = values @@ -74,7 +80,7 @@ def __init__(self, oper, *values): f'Could not complete computation {oper} of non pure {values_str}.') -class HintException(Exception): +class HintException(VmExceptionBase): def __init__(self, vm, exc_type, exc_value, exc_tb): tb_exception = traceback.TracebackException(exc_type, exc_value, exc_tb) # First item in the traceback is the call to exec, remove it. @@ -156,10 +162,10 @@ def compute_op1_addr(self, instruction: Instruction, op0: Optional[MaybeRelocata elif instruction.op1_addr is Instruction.Op1Addr.AP: base_addr = self.ap elif instruction.op1_addr is Instruction.Op1Addr.IMM: - assert instruction.off2 == 1, 'In immediate mode, off2 should be 1' + assert instruction.off2 == 1, 'In immediate mode, off2 should be 1.' 
base_addr = self.pc elif instruction.op1_addr is Instruction.Op1Addr.OP0: - assert op0 is not None, 'op0 must be known in double dereference' + assert op0 is not None, 'op0 must be known in double dereference.' base_addr = op0 else: raise NotImplementedError('Invalid op1_register value') @@ -173,6 +179,9 @@ def get_traceback_entries(self): entries = [] fp = self.fp for _ in range(MAX_TRACEBACK_ENTRIES): + if self.memory.get(fp - 2) == fp: + break + # Get the previous fp and the return pc. fp, ret_pc = self.memory.get(fp - 2), self.memory.get(fp - 1) @@ -210,7 +219,7 @@ class CompiledHint: class VirtualMachine: def __init__( self, program: ProgramBase, run_context: RunContext, - hint_locals: dict, static_locals: dict = {}, + hint_locals: Dict[str, Any], static_locals: Optional[Dict[str, Any]] = None, builtin_runners: Dict[str, BuiltinRunner] = {}, program_base: Optional[int] = None): """ hints - a dictionary from memory addresses to an executable object. @@ -259,7 +268,7 @@ def __init__( self.skip_instruction_execution = False from starkware.python import math_utils - self.static_locals = static_locals.copy() + self.static_locals = static_locals.copy() if static_locals is not None else {} self.static_locals.update({ 'PRIME': self.prime, 'fadd': lambda a, b, p=self.prime: (a + b) % p, @@ -288,7 +297,8 @@ def load_hints(self, program: Program, program_base: MaybeRelocatable): consts=lambda pc, ap, fp, memory, hint=hint: VmConsts( context=VmConstsContext( identifiers=program.identifiers, - evaluator=ExpressionEvaluator(self.prime, ap, fp, memory).eval, + evaluator=ExpressionEvaluator( + self.prime, ap, fp, memory, program.identifiers).eval, reference_manager=program.reference_manager, flow_tracking_data=hint.flow_tracking_data, memory=memory, @@ -327,6 +337,7 @@ def enter_scope(self, new_scope_locals: Optional[dict] = None): self.exec_scopes.append({**new_scope_locals, **self.builtin_runners}) def exit_scope(self): + assert len(self.exec_scopes) > 1, 'Cannot exit main 
scope.' self.exec_scopes.pop() def update_registers(self, instruction: Instruction, operands: Operands): @@ -517,6 +528,8 @@ def is_zero(self, value): This function can be overridden by subclasses. """ if not isinstance(value, int): + if isinstance(value, RelocatableValue) and value.offset >= 0: + return False raise PureValueError('jmp != 0', value) return value == 0 @@ -751,8 +764,8 @@ def verify_auto_deductions(self): def end_run(self): self.verify_auto_deductions() - assert len(self.exec_scopes) == 1, \ - 'Every enter_scope() requires a corresponding exit_scope().' + if len(self.exec_scopes) != 1: + raise VmExceptionBase('Every enter_scope() requires a corresponding exit_scope().') def get_perm_range_check_limits( diff --git a/src/starkware/cairo/lang/vm/vm_consts.py b/src/starkware/cairo/lang/vm/vm_consts.py index 3e3a7694..b8182dc9 100644 --- a/src/starkware/cairo/lang/vm/vm_consts.py +++ b/src/starkware/cairo/lang/vm/vm_consts.py @@ -13,7 +13,7 @@ MissingIdentifierError) from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingData, ReferenceManager -from starkware.cairo.lang.compiler.references import Reference +from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.type_system_visitor import simplify_type_system from starkware.cairo.lang.vm.memory_dict import MemoryDict @@ -45,11 +45,17 @@ def __getattr__(self, name: str): if name.startswith('__'): raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'") - return self.get_or_set_value(name, None) + try: + return self.get_or_set_value(name, None) + except FlowTrackingError: + raise FlowTrackingError(f"Reference '{name}' is revoked.") from None def __setattr__(self, name: str, value): assert value is not None, 'Setting a value to None is not allowed.' 
- self.get_or_set_value(name, value) + try: + self.get_or_set_value(name, value) + except FlowTrackingError: + raise FlowTrackingError(f"Reference '{name}' is revoked.") from None @abstractmethod def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): @@ -149,6 +155,7 @@ def handle_scope( instruction_offset=identifier.pc if isinstance(identifier, LabelDefinition) else None) handle_LabelDefinition = handle_scope + handle_FunctionDefinition = handle_scope def handle_StructDefinition( self, name: str, identifier: StructDefinition, scope: ScopedName, @@ -170,7 +177,9 @@ def handle_ReferenceDefinition( if set_value is None: expr = reference.eval( self._context.flow_tracking_data.ap_tracking) - expr, expr_type = simplify_type_system(expr) + expr, expr_type = simplify_type_system( + expr, + identifiers=self._context.identifiers) if isinstance(expr_type, TypeStruct): # If the reference is of type T, take its address and treat it as T*. assert isinstance(expr, ExprDeref), \ @@ -190,11 +199,12 @@ def handle_ReferenceDefinition( return VmConstsReference( context=self._context, struct_name=expr_type.pointee.scope, - reference_value=val, - add_addr_var=True) + reference_value=val) else: assert str(scope[-1:]) == name, 'Expecting scope to end with name.' - value, value_type = simplify_type_system(reference.value) + value, value_type = simplify_type_system( + reference.value, + identifiers=self._context.identifiers) assert isinstance(value, ExprDeref), f"""\ {scope} (= {value.format()}) does not reference memory and cannot be assigned.""" @@ -217,10 +227,9 @@ def raise_unsupported_error(self, name: ScopedName, identifier_type: str): class VmConstsReference(VmConstsBase): - def __init__(self, *, reference_value, struct_name: ScopedName, add_addr_var: bool, **kw): + def __init__(self, *, reference_value, struct_name: ScopedName, **kw): """ Constructs a VmConstsReference which allows accessing a typed reference fields. 
- If add_addr_var, the value of the reference itself can be accessed using self.address_. """ super().__init__(**kw) @@ -228,8 +237,14 @@ def __init__(self, *, reference_value, struct_name: ScopedName, add_addr_var: bo struct_name=struct_name, identifier_manager=self._context.identifiers)) object.__setattr__(self, '_reference_value', reference_value) - if add_addr_var: - object.__setattr__(self, 'address_', reference_value) + object.__setattr__(self, 'address_', reference_value) + + @property + def type_(self): + return VmConstsStruct( + context=self._context, + struct_definition=self._struct_definition, + ) def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): """ @@ -258,8 +273,7 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): return VmConstsReference( context=self._context, struct_name=expr_type.pointee.scope, - reference_value=self._context.memory[addr], - add_addr_var=True) + reference_value=self._context.memory[addr]) def is_simple_type(expr_type: CairoType) -> bool: @@ -283,6 +297,12 @@ def __init__(self, *, struct_definition: StructDefinition, **kw): super().__init__(**kw) object.__setattr__(self, '_struct_definition', struct_definition) + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self._struct_definition == other._struct_definition and \ + self._context is other._context + def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): assert set_value is None, 'Cannot change the value of a constant.' 
diff --git a/src/starkware/cairo/lang/vm/vm_consts_test.py b/src/starkware/cairo/lang/vm/vm_consts_test.py index 4ec89fc1..2c8c4f6f 100644 --- a/src/starkware/cairo/lang/vm/vm_consts_test.py +++ b/src/starkware/cairo/lang/vm/vm_consts_test.py @@ -15,7 +15,7 @@ from starkware.cairo.lang.compiler.preprocessor.reg_tracking import RegTrackingData from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName -from starkware.cairo.lang.compiler.type_system_visitor import mark_types_in_expr_resolved +from starkware.cairo.lang.compiler.type_system import mark_types_in_expr_resolved from starkware.cairo.lang.vm.vm_consts import VmConsts, VmConstsContext scope = ScopedName.from_string @@ -142,6 +142,7 @@ def test_references(): size=11, ), } + identifiers = IdentifierManager.from_dict(identifier_values) prime = 2**64 + 13 ap = 100 fp = 200 @@ -157,8 +158,8 @@ def test_references(): reference_ids=references, ) context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), - evaluator=ExpressionEvaluator(prime, ap, fp, memory).eval, + identifiers=identifiers, + evaluator=ExpressionEvaluator(prime, ap, fp, memory, identifiers).eval, reference_manager=reference_manager, flow_tracking_data=flow_tracking_data, memory=memory, @@ -238,9 +239,10 @@ def get_vm_consts(identifier_values, reference_manager, flow_tracking_data, memo """ Creates a simple VmConsts object. 
""" + identifiers = IdentifierManager.from_dict(identifier_values) context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), - evaluator=ExpressionEvaluator(2**64 + 13, 0, 0, memory).eval, + identifiers=identifiers, + evaluator=ExpressionEvaluator(2**64 + 13, 0, 0, memory, identifiers).eval, reference_manager=reference_manager, flow_tracking_data=flow_tracking_data, memory=memory, pc=9) return VmConsts(context=context, accessible_scopes=[ScopedName()]) @@ -256,7 +258,7 @@ def test_reference_rebinding(): reference_manager = ReferenceManager() flow_tracking_data = FlowTrackingDataActual(ap_tracking=RegTrackingData()) consts = get_vm_consts(identifier_values, reference_manager, flow_tracking_data) - with pytest.raises(FlowTrackingError, match='Reference ref revoked'): + with pytest.raises(FlowTrackingError, match="Reference 'ref' is revoked."): consts.ref flow_tracking_data = flow_tracking_data.add_reference( @@ -294,7 +296,7 @@ def test_reference_to_structs(): name=scope('ref'), ref=Reference( pc=0, - value=mark_types_in_expr_resolved(parse_expr('cast([100], T)')), + value=mark_types_in_expr_resolved(parse_expr('[cast(100, T*)]')), ap_tracking_data=RegTrackingData(group=0, offset=2), ), ) @@ -309,6 +311,8 @@ def test_reference_to_structs(): assert memory[203] == 300 assert consts.ref.x.x.address_ == 300 + assert consts.ref.type_ == consts.T + def test_missing_attributes(): identifier_values = { @@ -413,6 +417,7 @@ def test_revoked_reference(): full_name=scope('x'), cairo_type=TypeFelt(), references=[] ), } + identifiers = IdentifierManager.from_dict(identifier_values) prime = 2**64 + 13 ap = 100 fp = 200 @@ -423,13 +428,16 @@ def test_revoked_reference(): reference_ids={scope('x'): ref_id}, ) context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), - evaluator=ExpressionEvaluator(prime, ap, fp, memory).eval, + identifiers=identifiers, + evaluator=ExpressionEvaluator(prime, ap, fp, memory, 
identifiers).eval, reference_manager=reference_manager, flow_tracking_data=flow_tracking_data, memory=memory, pc=0) consts = VmConsts(context=context, accessible_scopes=[ScopedName()]) - with pytest.raises(FlowTrackingError, match='Failed to deduce ap.'): - assert consts.x + with pytest.raises(FlowTrackingError, match="Reference 'x' is revoked."): + consts.x + + with pytest.raises(FlowTrackingError, match="Reference 'x' is revoked."): + consts.x = 85 diff --git a/src/starkware/cairo/lang/vm/vm_test.py b/src/starkware/cairo/lang/vm/vm_test.py index 107ab8a3..346d0eda 100644 --- a/src/starkware/cairo/lang/vm/vm_test.py +++ b/src/starkware/cairo/lang/vm/vm_test.py @@ -9,6 +9,7 @@ from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue from starkware.cairo.lang.vm.vm import ( InconsistentAutoDeductionError, RunContext, VirtualMachine, VmException) +from starkware.python.test_utils import maybe_raises PRIME = 2**64 + 13 @@ -88,6 +89,23 @@ def test_jnz(): [7, 6, 5, 4, 3, 2, 1, 0] + [4, 3, 2, 1, 0] * 5 +@pytest.mark.parametrize('offset', [0, -1]) +def test_jnz_relocatables(offset: int): + code = """ + jmp body if [ap - 1] != 0 + [ap] = 0; ap++ + body: + [ap] = 1; ap++ + """ + relocatable_value = RelocatableValue(segment_index=5, offset=offset) + error_message = \ + None if relocatable_value.offset >= 0 else \ + f'Could not complete computation jmp != 0 of non pure value {relocatable_value}' + with maybe_raises(expected_exception=VmException, error_message=error_message): + vm = run_single(code, 2, ap=102, extra_mem={101: relocatable_value}) + assert vm.run_context.memory[102] == 1 + + def test_call_ret(): code = """ [fp] = 1000; ap++ diff --git a/src/starkware/cairo/sharp/config.json b/src/starkware/cairo/sharp/config.json index cabd3932..a239ed49 100644 --- a/src/starkware/cairo/sharp/config.json +++ b/src/starkware/cairo/sharp/config.json @@ -1,5 +1,5 @@ { "prover_url": "https://ropsten-v1.provingservice.io", - "verifier_address": 
async def async_check_output(args: Union[str, List[str]], shell: bool = False, cwd=None, env=None):
    """
    An async equivalent to subprocess.check_output().

    args is either a list of program arguments or a single string. With shell=True the string is
    executed by 'bash -e -c'; with shell=False a string is treated as the name of the program to
    run (as subprocess does).
    cwd and env are forwarded to asyncio.create_subprocess_exec().

    Returns the bytes written by the process to its stdout (stderr is not captured).
    Raises AssertionError if the process exits with a nonzero return code.
    """
    if shell:
        assert isinstance(args, str), 'args must be a string where shell=True.'
        # Pass '-e' to stop after failure if args consists of multiple commands.
        args = ['bash', '-e', '-c', args]
    elif isinstance(args, str):
        # Without this, '*args' below would split the string into single characters.
        args = [args]

    proc = await asyncio.create_subprocess_exec(
        *args, cwd=cwd, env=env, stdout=asyncio.subprocess.PIPE)
    # Use communicate() rather than wait(): it drains stdout while waiting, so a process that
    # writes more than the OS pipe buffer can hold does not block forever.
    stdout, _ = await proc.communicate()
    # Raise explicitly (instead of using an assert statement) so that the check is not removed
    # when running with 'python -O'.
    if proc.returncode != 0:
        raise AssertionError(f'Command {args} failed with return code {proc.returncode}.')
    return stdout
def random_test(n_nightly_runs: int = 10, seed: Optional[int] = None):
    """
    A decorator factory for tests that use the python global random object.

    The decorated test is parametrized with a `seed` argument (so it must accept a `seed`
    parameter, even if it does nothing with it). Before each run the global random object is
    seeded with that value, and its previous state is restored when the run ends.
    If the test raises an exception, the seed is printed together with instructions on how to
    reproduce the run.

    The seeds are computed once, when the decorator is applied, by _get_seeds():
      * If the `RANDOM_TEST_SEED` environment variable is a number, the test runs once with it.
      * If it is set to `random`, random seeds are used.
      * Otherwise, if `seed` was given to the decorator, the test runs once with that seed
        (even on nightly runs).
      * Otherwise random seeds are used.
    When random seeds are used, the number of runs is `RANDOM_TEST_N_RUNS` if that environment
    variable is defined; otherwise it is `n_nightly_runs` when `NIGHTLY_TEST` is set to 1, and 1
    in any other case.
    Setting an environment variable can be done by prefixing the command line with
    `RANDOM_TEST_SEED=10` for example.
    """
    def decorator(test_func: Callable):
        chosen_seeds = _get_seeds(n_nightly_runs=n_nightly_runs, seed=seed)

        def seeded_run(*args, seed, **kwargs):
            # Remember the global random state so that the test does not affect other tests.
            saved_state = random.getstate()
            random.seed(seed)
            try:
                # For a coroutine test this yields the (not yet awaited) coroutine; an exception
                # raised while awaiting it is thrown back into this generator by the wrapper.
                yield test_func(*args, seed=seed, **kwargs)
            except Exception:
                _print_seed(seed=seed, decorator_name='random_test')
                raise
            finally:
                random.setstate(saved_state)

        wrapped_test = _convert_function_to_function_or_coroutine(
            caller_func=seeded_run, callee_func=test_func)
        # We need to use pytest.mark.parametrize rather than running the test in a for loop. If we
        # do the latter, pytest won't re-create the fixtures for each run.
        return pytest.mark.parametrize('seed', chosen_seeds)(wrapped_test)

    return decorator
def safe_zip(*iterables: Iterable[Any]) -> Iterable:
    """
    Zips iterables. Makes sure the lengths of all iterables are equal.

    This is a generator: tuples are yielded lazily, and an AssertionError is raised when the
    first missing element is reached (i.e., only once the generator is consumed up to the end of
    the shortest iterable).
    """
    # A unique object that cannot appear in any of the iterables.
    sentinel = object()
    for combo in itertools.zip_longest(*iterables, fillvalue=sentinel):
        # Compare by identity: 'sentinel in combo' would call __eq__ of the zipped elements,
        # which may raise or return a non-boolean value (e.g., for array-like objects).
        if any(item is sentinel for item in combo):
            # Raised explicitly so that the check is kept when running with 'python -O'.
            raise AssertionError('Iterables to safe_zip are not equal in length.')
        yield combo
diff --git a/src/starkware/starknet/CMakeLists.txt b/src/starkware/starknet/CMakeLists.txt new file mode 100644 index 00000000..3efe01cd --- /dev/null +++ b/src/starkware/starknet/CMakeLists.txt @@ -0,0 +1,8 @@ +add_subdirectory(cli) +add_subdirectory(compiler) +add_subdirectory(core) +add_subdirectory(definitions) +add_subdirectory(public) +add_subdirectory(scripts) +add_subdirectory(security) +add_subdirectory(services) diff --git a/src/starkware/starknet/apps/amm_sample/amm_sample.cairo b/src/starkware/starknet/apps/amm_sample/amm_sample.cairo new file mode 100644 index 00000000..847d1a99 --- /dev/null +++ b/src/starkware/starknet/apps/amm_sample/amm_sample.cairo @@ -0,0 +1,141 @@ +%lang starknet +%builtins pedersen range_check + +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash import hash2 +from starkware.cairo.common.math import assert_le, assert_nn_le, unsigned_div_rem +from starkware.starknet.core.storage.storage import Storage, storage_read, storage_write + +# The maximum amount of each token that belongs to the AMM. +const BALANCE_UPPER_BOUND = %[ 2**64 %] + +const TOKEN_TYPE_A = 1 +const TOKEN_TYPE_B = 2 + +# Ensure the user's balances are much smaller than the pool's balance. +const POOL_UPPER_BOUND = %[ 2**30 %] +const ACCOUNT_BALANCE_BOUND = %[ 2**30 // 1000 %] + +# A map from account and token type to the corresponding balance of that account. +@storage_var +func account_balance(account_id : felt, token_type : felt) -> (balance : felt): +end + +# A map from token type to the corresponding balance of the pool. +@storage_var +func pool_balance(token_type : felt) -> (balance : felt): +end + +# Adds amount to the account's balance for the given token. +# amount may be positive or negative. +# Assert before setting that the balance does not exceed the upper bound. 
func modify_account_balance{storage_ptr : Storage*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        account_id : felt, token_type : felt, amount : felt):
    # Read the balance currently stored for the pair (account_id, token_type).
    let (current_balance) = account_balance.read(account_id, token_type)
    # The addition is done on field elements; callers decrease a balance by passing a negated
    # amount (do_swap passes amount=-amount_from).
    tempvar new_balance = current_balance + amount
    # Require 0 <= new_balance <= BALANCE_UPPER_BOUND - 1. This is checked before the write, so
    # an out-of-range result (including a decrease below zero) never reaches the storage.
    assert_nn_le(new_balance, BALANCE_UPPER_BOUND - 1)
    # Store the updated balance.
    account_balance.write(account_id=account_id, token_type=token_type, value=new_balance)
    return ()
end
# Adds demo tokens to the given account.
# token_a_amount and token_b_amount are added to the account's balances of TOKEN_TYPE_A and
# TOKEN_TYPE_B respectively.
@external
func add_demo_token{storage_ptr : Storage*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        account_id : felt, token_a_amount : felt, token_b_amount : felt):
    # Make sure the account's balance is much smaller than pool init balance.
    # NOTE(review): only the amount added in this call is bounded here; the accumulated balance
    # is limited only by the BALANCE_UPPER_BOUND check in modify_account_balance - confirm that
    # repeated calls are acceptable.
    assert_nn_le(token_a_amount, ACCOUNT_BALANCE_BOUND - 1)
    assert_nn_le(token_b_amount, ACCOUNT_BALANCE_BOUND - 1)

    # Credit both token types.
    modify_account_balance(account_id=account_id, token_type=TOKEN_TYPE_A, amount=token_a_amount)
    modify_account_balance(account_id=account_id, token_type=TOKEN_TYPE_B, amount=token_b_amount)
    return ()
end
async def deploy(args, command_args):
    """
    Handles the 'deploy' subcommand: sends a deploy transaction to the StarkNet gateway and
    prints the contract address and the transaction ID.

    args is the namespace created by the top-level parser; the flags of this subcommand (taken
    from command_args) are parsed into the same namespace.

    Raises ValueError if --address is given but is not a hexadecimal number, and AssertionError
    if the gateway does not report that the transaction was received.
    """
    parser = argparse.ArgumentParser(
        description='Sends a deploy transaction to StarkNet.')
    parser.add_argument(
        '--address', type=str,
        help='An optional address specifying where the contract will be deployed. '
        'If the address is not specified, the contract will be deployed in a random address.')
    parser.add_argument(
        '--contract', type=argparse.FileType('r'),
        help='The contract definition to deploy.', required=True)
    parser.parse_args(command_args, namespace=args)

    # argparse.FileType has already opened the contract file. Use it as a context manager so
    # that it is closed on every exit path (including the error paths below).
    with args.contract as contract_file:
        gateway_client = get_gateway_client(args)

        if args.address is None:
            address = fields.ContractAddressField.get_random_value()
        else:
            # Only the conversion of the user input is guarded, so that an unrelated ValueError
            # is not reported as a bad address.
            try:
                address = int(args.address, 16)
            except ValueError:
                raise ValueError('Invalid address format.')

        contract_definition = ContractDefinition.loads(contract_file.read())

    tx = Deploy(
        contract_address=address,
        contract_definition=contract_definition)

    gateway_response = await gateway_client.add_transaction(tx=tx)
    assert gateway_response['code'] == StarkErrorCode.TRANSACTION_RECEIVED.name, \
        f'Failed to send transaction. Response: {gateway_response}.'
    print(f"""\
Deploy transaction was sent.
Contract address: 0x{address:064x}.
Transaction ID: {gateway_response['tx_id']}.""")
def handle_network_param(args):
    """
    Gives default values to the gateways if the network parameter is set.

    The network is taken from args.network, or from the STARKNET_NETWORK environment variable if
    args.network is None. If a network is found, args.gateway_url and args.feeder_gateway_url
    are filled in place, unless they were already given explicitly.

    Returns 0 on success (including the case where no network is specified), and 1 if the
    network name is unknown (in this case args is left unchanged).
    """
    network = os.environ.get('STARKNET_NETWORK') if args.network is None else args.network
    if network is None:
        # Nothing to do; the URLs must be given explicitly.
        return 0

    if network != 'alpha':
        # Report on stderr (as main() does for errors), keeping stdout for the command output.
        print(f"Unknown network '{network}'.", file=sys.stderr)
        return 1

    dns = 'alpha.starknet.io'
    if args.gateway_url is None:
        args.gateway_url = f'https://{dns}/gateway'

    if args.feeder_gateway_url is None:
        args.feeder_gateway_url = f'https://{dns}/feeder_gateway'

    return 0
async def main():
    """
    Entry point of the StarkNet CLI.

    Parses the global flags and the subcommand name, applies the network defaults and invokes
    the handler of the requested subcommand with the remaining (unparsed) arguments.

    Returns the exit status of the tool: the value returned by the handler on success, or 1 if
    the network is unknown or the handler raised an exception (the error is printed to stderr).
    """
    # Maps each subcommand name to its (async) handler.
    subparsers = {
        'deploy': deploy,
        'invoke': functools.partial(invoke_or_call, call=False),
        'call': functools.partial(invoke_or_call, call=True),
        'tx_status': tx_status,
        'get_block': get_block,
        'get_code': get_code,
        'get_storage_at': get_storage_at,
    }
    parser = argparse.ArgumentParser(description='A tool to communicate with StarkNet.')
    parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}')
    parser.add_argument('--network', type=str, help='The name of Network.')

    parser.add_argument('--gateway_url', type=str, help='The URL of a StarkNet gateway.')
    parser.add_argument(
        '--feeder_gateway_url', type=str, help='The URL of a StarkNet feeder gateway.')
    parser.add_argument('command', choices=subparsers.keys())

    # Flags that are not recognized here are left for the parser of the subcommand.
    args, unknown = parser.parse_known_args()

    ret = handle_network_param(args)
    if ret != 0:
        return ret

    try:
        # Invoke the requested command.
        return await subparsers[args.command](args, unknown)
    except Exception as exc:
        print(f'Error: {type(exc).__name__}: {exc}', file=sys.stderr)
        return 1


if __name__ == '__main__':
    # Propagate the status returned by main() to the shell; without sys.exit() the process
    # exits with status 0 even when the command failed.
    sys.exit(asyncio.run(main()))
def process_calldata(
        calldata_ptr: Expression, identifiers: IdentifierManager,
        struct_def: StructDefinition) -> ArgList:
    """
    Builds the argument list that corresponds to the members of 'struct_def', where the value of
    each member is read from the calldata:
      member_name=cast([calldata_ptr + member_offset], member_type).

    Only structs whose members are all felts are supported; a PreprocessorError is raised for
    the first member of any other type.
    Note that 'identifiers' is currently not used.
    """
    assignments = []
    for name, member in struct_def.members.items():
        member_type = member.cairo_type
        if not isinstance(member_type, TypeFelt):
            raise PreprocessorError(
                f'Unsupported argument type {member_type.format()}.',
                location=member_type.location)

        member_location = member.location
        # The expression [calldata_ptr + offset].
        member_addr = ExprOperator(
            calldata_ptr, '+', ExprConst(member.offset, location=member_location),
            location=member_location)
        member_value = ExprDeref(addr=member_addr, location=member_location)

        assignments.append(ExprAssignment(
            identifier=ExprIdentifier(name=name, location=member_location),
            expr=ExprCast(
                expr=member_value, dest_type=member_type, location=member_type.location),
            location=struct_def.location))

    # One Notes object is needed before each argument and one after the last argument.
    n_notes = len(assignments) + 1
    return ArgList(
        args=assignments, notes=[Notes()] * n_notes, has_trailing_comma=True,
        location=struct_def.location)
def test_process_calldata_failure():
    """
    Checks that process_calldata() raises a PreprocessorError for a struct with a non-felt
    member (arg_a is of type felt*).
    """
    # Imported locally since this module does not import 're' at the top of the file.
    import re

    identifier_values = {
        scope('MyStruct'): StructDefinition(
            full_name=scope('MyStruct'),
            members={
                'arg_a': MemberDefinition(offset=0, cairo_type=FELT_STAR),
                'arg_b': MemberDefinition(offset=1, cairo_type=TypeFelt()),
            },
            size=11,
        ),
    }
    identifiers = IdentifierManager.from_dict(identifier_values)

    calldata_ptr = ExprIdentifier('calldata_ptr')

    # 'match' is a regular expression and the message contains '*' and '.', so it must be
    # escaped in order to be matched literally.
    with pytest.raises(
            PreprocessorError, match=re.escape('Unsupported argument type felt*.')):
        process_calldata(
            calldata_ptr=calldata_ptr, identifiers=identifiers,
            struct_def=get_struct_definition(
                struct_name=scope('MyStruct'), identifier_manager=identifiers))
+ """ + wrapper_scope = program.identifiers.get_scope(WRAPPER_SCOPE) + return { + func_name: ContractEntryPoint( + selector=get_selector_from_name(func_name), + offset=func_def.pc) + for func_name, func_def in wrapper_scope.identifiers.items() + if isinstance(func_def, FunctionDefinition)} + + +def compile_starknet_files( + files, debug_info: bool = False, + disable_hint_validation: bool = False) -> ContractDefinition: + module_reader = get_module_reader(cairo_path=[]) + + pass_manager = starknet_pass_manager( + prime=DEFAULT_PRIME, read_module=module_reader.read, + disable_hint_validation=disable_hint_validation) + + program, preprocessed = compile_cairo_ex( + code=get_codes(files), debug_info=debug_info, pass_manager=pass_manager) + + # Dump and load program, so that it is converted to the canonical form. + program_schema = program.Schema() + program = program_schema.load(data=program_schema.dump(obj=program)) + + assert isinstance(preprocessed, StarknetPreprocessedProgram) + return ContractDefinition( + program=program, entry_points=list(get_entry_points(program=program).values()), + abi=preprocessed.abi) + + +def assemble_starknet_contract( + preprocessed_program: StarknetPreprocessedProgram, main_scope: ScopedName, + add_debug_info: bool, file_contents_for_debug_info: Dict[str, str]) -> ContractDefinition: + assert isinstance(preprocessed_program, StarknetPreprocessedProgram) + program = assemble( + preprocessed_program, main_scope=main_scope, add_debug_info=add_debug_info, + file_contents_for_debug_info=file_contents_for_debug_info) + + return ContractDefinition( + program=program, entry_points=list(get_entry_points(program=program).values()), + abi=preprocessed_program.abi) + + +def main(): + parser = argparse.ArgumentParser(description='A tool to compile StarkNet contracts.') + parser.add_argument('--abi', type=argparse.FileType('w'), help="Output the contract's ABI.") + parser.add_argument( + '--disable_hint_validation', action='store_true', help='Disable 
the hint validation.') + + def pass_manager_factory(args: argparse.Namespace, module_reader: ModuleReader) -> PassManager: + return starknet_pass_manager( + prime=args.prime, + read_module=module_reader.read, + disable_hint_validation=args.disable_hint_validation) + + try: + cairo_compile_add_common_args(parser) + args = parser.parse_args() + preprocessed = cairo_compile_common( + args=args, pass_manager_factory=pass_manager_factory, + assemble_func=assemble_starknet_contract) + assert isinstance(preprocessed, StarknetPreprocessedProgram) + if args.abi is not None: + json.dump(preprocessed.abi, args.abi, indent=4, sort_keys=True) + args.abi.write('\n') + except LocationError as err: + print(err, file=sys.stderr) + return 1 + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/starkware/starknet/compiler/conftest.py b/src/starkware/starknet/compiler/conftest.py new file mode 100644 index 00000000..98c3adc3 --- /dev/null +++ b/src/starkware/starknet/compiler/conftest.py @@ -0,0 +1,3 @@ +import pytest + +pytest.register_assert_rewrite('starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils') diff --git a/src/starkware/starknet/compiler/starknet_pass_manager.py b/src/starkware/starknet/compiler/starknet_pass_manager.py new file mode 100644 index 00000000..fbd28118 --- /dev/null +++ b/src/starkware/starknet/compiler/starknet_pass_manager.py @@ -0,0 +1,38 @@ +from typing import Callable, Tuple + +from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import ( + ModuleCollector, default_pass_manager) +from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManager, VisitorStage +from starkware.starknet.compiler.starknet_preprocessor import StarknetPreprocessor +from starkware.starknet.compiler.storage_var import ( + StorageVarDeclVisitor, StorageVarImplentationVisitor) +from starkware.starknet.security.hints_whitelist import get_hints_whitelist + + +def starknet_pass_manager( + prime: int, read_module: 
Callable[[str], Tuple[str, str]], + opt_unused_functions: bool = True, disable_hint_validation: bool = False) -> PassManager: + hint_whitelist = None if disable_hint_validation else get_hints_whitelist() + manager = default_pass_manager( + prime=prime, read_module=read_module, preprocessor_cls=StarknetPreprocessor, + opt_unused_functions=opt_unused_functions, + preprocessor_kwargs=dict(hint_whitelist=hint_whitelist)) + # Use ModuleCollector.additional_modules to import necessary modules, whose import line + # may be added after the module_collector phase. + manager.replace('module_collector', ModuleCollector( + read_module=read_module, + additional_modules=[ + 'starkware.cairo.common.cairo_builtins', + 'starkware.cairo.common.hash', + 'starkware.starknet.core.storage.storage', + ])) + + manager.add_before( + existing_stage='identifier_collector', + new_stage_name='storage_var_signature', + new_stage=VisitorStage(StorageVarDeclVisitor, modify_ast=True)) + manager.add_after( + existing_stage='struct_collector', + new_stage_name='storage_var_implementation', + new_stage=VisitorStage(StorageVarImplentationVisitor, modify_ast=True)) + return manager diff --git a/src/starkware/starknet/compiler/starknet_preprocessor.py b/src/starkware/starknet/compiler/starknet_preprocessor.py new file mode 100644 index 00000000..6578b3e5 --- /dev/null +++ b/src/starkware/starknet/compiler/starknet_preprocessor.py @@ -0,0 +1,394 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple + +from starkware.cairo.lang.compiler.ast.cairo_types import ( + CairoType, TypeFelt, TypePointer, TypeStruct) +from starkware.cairo.lang.compiler.ast.code_elements import ( + BuiltinsDirective, CodeElementCompoundAssertEq, CodeElementFuncCall, CodeElementFunction, + CodeElementHint, CodeElementInstruction, LangDirective) +from starkware.cairo.lang.compiler.ast.expr import ( + ArgList, ExprAssignment, ExprCast, ExprConst, ExprDeref, Expression, ExprIdentifier, + ExprOperator, ExprReg) 
+from starkware.cairo.lang.compiler.ast.instructions import ( + AddApInstruction, InstructionAst, RetInstruction) +from starkware.cairo.lang.compiler.ast.rvalue import RvalueFuncCall +from starkware.cairo.lang.compiler.ast.types import TypedIdentifier +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.identifier_definition import ( + AliasDefinition, FunctionDefinition, FutureIdentifierDefinition, StructDefinition) +from starkware.cairo.lang.compiler.instruction import Register +from starkware.cairo.lang.compiler.preprocessor.preprocessor import ( + PreprocessedProgram, Preprocessor) +from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError +from starkware.cairo.lang.compiler.program import CairoHint +from starkware.cairo.lang.compiler.references import create_simple_ref_expr +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.starknet.compiler.calldata_parser import process_calldata +from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE +from starkware.starknet.security.secure_hints import HintsWhitelist, InsecureHintError + +EXTERNAL_DECORATOR = 'external' +VIEW_DECORATOR = 'view' +WRAPPER_SCOPE = ScopedName.from_string('__wrappers__') + + +@dataclasses.dataclass +class StarknetPreprocessedProgram(PreprocessedProgram): + # JSON dict that contains information on the callable functions in the contract. + abi: Any + + +class StarknetPreprocessor(Preprocessor): + def __init__(self, **kwargs): + kwargs = dict(kwargs) + supported_decorators = kwargs.pop('supported_decorators', { + EXTERNAL_DECORATOR, VIEW_DECORATOR}) + + # A whitelist of allowed hints. + # None means that any hint is allowed. 
+ self.hint_whitelist: Optional[HintsWhitelist] = kwargs.pop('hint_whitelist', None) + + super().__init__(supported_decorators=supported_decorators, **kwargs) + + # A mapping from name to offset in the os_context that is passed to the contract. + # Unfortunately we need to process the builtins directive before we can initialize it. + self.os_context: Optional[Dict[str, int]] = None + # JSON dict for the ABI output. + self.abi: List[dict] = [] + + def get_external_decorator(self, elm: CodeElementFunction) -> Optional[ExprIdentifier]: + """ + If the function has the @external or @view decorator, returns it. + Otherwise, returns None. + """ + for decorator in elm.decorators: + if decorator.name in [EXTERNAL_DECORATOR, VIEW_DECORATOR]: + return decorator + + return None + + def visit_BuiltinsDirective(self, directive: BuiltinsDirective): + super().visit_BuiltinsDirective(directive) + assert self.builtins is not None + if 'storage' in self.builtins: + raise PreprocessorError( + "'storage' may not appear in the builtins directive.", + location=directive.location) + + def visit_LangDirective(self, directive: LangDirective): + if directive.name != STARKNET_LANG_DIRECTIVE: + raise PreprocessorError( + f'Unsupported %lang directive. Are you using the correct compiler?', + location=directive.location, + ) + + def get_os_context(self) -> Dict[str, int]: + if self.os_context is None: + builtins = [] if self.builtins is None else self.builtins + + os_context = {'storage_ptr': 0} + for index, builtin_name in enumerate(builtins, len(os_context)): + ptr_name = f'{builtin_name}_ptr' + assert os_context.setdefault(ptr_name, index) == index, \ + f'os_context.{ptr_name} was redefined.' + + self.os_context = os_context + return self.os_context + + def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str): + """ + Generates a wrapper that converts between the StarkNet contract ABI and the + Cairo calling convention. 
+ + Arguments: + elm - the CodeElementFunction of the wrapped function. + func_alias_name - an alias for the FunctionDefinition in the current scope. + """ + + os_context = self.get_os_context() + + func_location = elm.identifier.location + + # We expect the call stack to look as follows: + # pointer to builtins struct. + # pointer to the call data array. + # ret_fp. + # ret_pc. + builtins_ptr = ExprDeref( + addr=ExprOperator( + ExprReg(reg=Register.FP, location=func_location), + '+', + ExprConst(-4, location=func_location), + location=func_location), + location=func_location) + calldata_ptr = ExprDeref( + addr=ExprOperator( + ExprReg(reg=Register.FP, location=func_location), + '+', + ExprConst(-3, location=func_location), + location=func_location), + location=func_location) + + implicit_arguments = None + + implicit_arguments_identifiers: Dict[str, TypedIdentifier] = {} + if elm.implicit_arguments is not None: + args = [] + for typed_identifier in elm.implicit_arguments.identifiers: + ptr_name = typed_identifier.name + if ptr_name not in os_context: + raise PreprocessorError( + f"Unexpected implicit argument '{ptr_name}' in an external function.", + location=typed_identifier.identifier.location) + + implicit_arguments_identifiers[ptr_name] = typed_identifier + + # Add the assignment expression 'ptr_name = ptr_name' to the implicit arg list. + args.append(ExprAssignment( + identifier=typed_identifier.identifier, + expr=typed_identifier.identifier, + location=typed_identifier.location, + )) + + implicit_arguments = ArgList( + args=args, notes=[], has_trailing_comma=True, + location=elm.implicit_arguments.location) + + return_args_exprs: List[Expression] = [] + + # Create references.
+ for ptr_name, index in os_context.items(): + ref_name = self.current_scope + ptr_name + + arg_identifier = implicit_arguments_identifiers.get(ptr_name) + if arg_identifier is None: + location = func_location + cairo_type: CairoType = TypeFelt(location=location) + else: + location = arg_identifier.location + cairo_type = self.resolve_type(arg_identifier.get_type()) + + # Add a reference of the form + # 'let ref_name = [cast(builtins_ptr + index, cairo_type*)]'. + self.add_reference( + name=ref_name, + value=ExprDeref( + addr=ExprCast( + ExprOperator( + builtins_ptr, '+', ExprConst(index, location=location), + location=location), + dest_type=TypePointer(pointee=cairo_type, location=cairo_type.location), + location=cairo_type.location), + location=location), + cairo_type=cairo_type, + location=location, + require_future_definition=False) + + assert index == len(return_args_exprs), 'Unexpected index.' + + return_args_exprs.append(ExprIdentifier(name=ptr_name, location=func_location)) + + arg_struct_def = self.get_struct_definition( + name=ScopedName.from_string(func_alias_name) + CodeElementFunction.ARGUMENT_SCOPE, + location=func_location) + self.visit(CodeElementFuncCall( + func_call=RvalueFuncCall( + func_ident=ExprIdentifier(name=func_alias_name, location=func_location), + arguments=process_calldata( + calldata_ptr=calldata_ptr, + identifiers=self.identifiers, + struct_def=arg_struct_def + ), + implicit_arguments=implicit_arguments, + location=func_location))) + + ret_struct_name = ScopedName.from_string(func_alias_name) + CodeElementFunction.RETURN_SCOPE + ret_struct_type = self.resolve_type(TypeStruct(ret_struct_name, False)) + ret_struct_def = self.get_struct_definition( + name=ret_struct_name, + location=func_location) + ret_struct_expr = create_simple_ref_expr( + reg=Register.AP, offset=-ret_struct_def.size, cairo_type=ret_struct_type, + location=func_location) + self.add_reference( + name=self.current_scope + 'ret_struct', + value=ret_struct_expr, + 
cairo_type=ret_struct_type, + require_future_definition=False, + location=func_location) + + # Add function return values. + retdata_size, retdata_ptr = self.process_retdata( + ret_struct_ptr=ExprIdentifier(name='ret_struct'), + ret_struct_type=ret_struct_type, struct_def=ret_struct_def, + location=func_location, + ) + return_args_exprs += [retdata_size, retdata_ptr] + + # Push the return values. + self.push_compound_expressions( + compound_expressions=[self.simplify_expr_as_felt(expr) for expr in return_args_exprs], + location=func_location, + ) + + # Add a ret instruction. + self.visit(CodeElementInstruction( + instruction=InstructionAst( + body=RetInstruction(), + inc_ap=False, + location=func_location))) + + # Add an entry to the ABI. + external_decorator = self.get_external_decorator(elm) + assert external_decorator is not None + is_view = external_decorator.name == 'view' + self.add_abi_entry( + name=elm.name, arg_struct_def=arg_struct_def, ret_struct_def=ret_struct_def, + is_view=is_view) + + def add_abi_entry( + self, name: str, arg_struct_def: StructDefinition, ret_struct_def: StructDefinition, + is_view: bool): + """ + Adds an entry describing the function to the contract's ABI. 
+ """ + inputs = [] + outputs = [] + for m_name, member in arg_struct_def.members.items(): + assert isinstance(member.cairo_type, TypeFelt) + inputs.append({ + 'name': m_name, + 'type': 'felt', + }) + for m_name, member in ret_struct_def.members.items(): + assert isinstance(member.cairo_type, TypeFelt) + outputs.append({ + 'name': m_name, + 'type': 'felt', + }) + res = { + 'name': name, + 'type': 'function', + 'inputs': inputs, + 'outputs': outputs, + } + if is_view: + res['stateMutability'] = 'view' + self.abi.append(res) + + def get_program(self) -> StarknetPreprocessedProgram: + program = super().get_program() + return StarknetPreprocessedProgram( # type: ignore + **program.__dict__, + abi=self.abi, + ) + + def process_retdata( + self, ret_struct_ptr: Expression, ret_struct_type: CairoType, + struct_def: StructDefinition, + location: Optional[Location]) -> Tuple[Expression, Expression]: + """ + Processes the return values and returns retdata_size and retdata_ptr. + """ + + # Verify all of the return types are felts. + for _, member_def in struct_def.members.items(): + cairo_type = member_def.cairo_type + if not isinstance(cairo_type, TypeFelt): + raise PreprocessorError( + f'Unsupported argument type {cairo_type.format()}.', + location=cairo_type.location) + + self.add_reference( + name=self.current_scope + 'retdata_ptr', + value=ExprDeref( + addr=ExprReg(reg=Register.AP), + location=location, + ), + cairo_type=TypePointer(TypeFelt()), + require_future_definition=False, + location=location) + + self.visit(CodeElementHint( + hint_code='memory[ap] = segments.add()', n_prefix_newlines=0, location=location)) + + # Skip check of hint whitelist as it fails before the workaround below.
+ super().visit_CodeElementInstruction(CodeElementInstruction(InstructionAst( + body=AddApInstruction(ExprConst(1)), + inc_ap=False, + location=location))) + + # Remove the references from the last instruction's flow tracking as they are + # not needed by the hint and they cause the hint whitelist to fail. + self.instructions[-1].flow_tracking_data = dataclasses.replace( + self.instructions[-1].flow_tracking_data, reference_ids={}) + self.visit(CodeElementCompoundAssertEq( + ExprDeref( + ExprCast(ExprIdentifier('retdata_ptr'), TypePointer(ret_struct_type))), + ret_struct_ptr)) + + return (ExprConst(struct_def.size), ExprIdentifier('retdata_ptr')) + + def visit_CodeElementFunction(self, elm: CodeElementFunction): + super().visit_CodeElementFunction(elm) + + external_decorator = self.get_external_decorator(elm) + if external_decorator is None: + return + + location = elm.identifier.location + + # Retrieve the canonical name of the function before switching scopes. + _, func_canonical_name = self.get_label(elm.name, location=location) + assert func_canonical_name is not None + + self.flow_tracking.revoke() + with self.scoped(WRAPPER_SCOPE, parent=elm), self.set_reference_states({}): + current_wrapper_scope = self.current_scope + elm.name + + self.add_name_definition( + current_wrapper_scope, + FunctionDefinition( # type: ignore + pc=self.current_pc, + decorators=[identifier.name for identifier in elm.decorators], + ), + location=elm.identifier.location, + require_future_definition=False) + + with self.scoped(current_wrapper_scope, parent=elm): + # Generate an alias that will allow us to call the original function. 
+ func_alias_name = f'__wrapped_func' + alias_canonical_name = current_wrapper_scope + func_alias_name + self.add_future_definition( + name=alias_canonical_name, + future_definition=FutureIdentifierDefinition( + identifier_type=AliasDefinition), + ) + + self.add_name_definition( + name=alias_canonical_name, + identifier_definition=AliasDefinition(destination=func_canonical_name), + location=location) + + self.create_func_wrapper(elm=elm, func_alias_name=func_alias_name) + + def visit_CodeElementInstruction(self, elm: CodeElementInstruction): + if self.hint_whitelist is not None: + hint = self.next_instruction_hint + if hint is not None: + try: + self.hint_whitelist.verify_hint_secure( + hint=CairoHint( + code=hint.hint_code, + accessible_scopes=self.accessible_scopes, + flow_tracking_data=self.flow_tracking.get(), + ), + reference_manager=self.flow_tracking.reference_manager) + except InsecureHintError: + raise PreprocessorError( + """\ +Hint is not whitelisted. +This may indicate that this library function cannot be used in StarkNet contracts.""", + location=hint.location) + + super().visit_CodeElementInstruction(elm) diff --git a/src/starkware/starknet/compiler/starknet_preprocessor_test.py b/src/starkware/starknet/compiler/starknet_preprocessor_test.py new file mode 100644 index 00000000..31b1d1be --- /dev/null +++ b/src/starkware/starknet/compiler/starknet_preprocessor_test.py @@ -0,0 +1,244 @@ +from starkware.cairo.lang.compiler.identifier_definition import FunctionDefinition +from starkware.starknet.compiler.starknet_preprocessor import WRAPPER_SCOPE +from starkware.starknet.compiler.test_utils import preprocess_str, verify_exception + + +def test_builtin_directive_after_external(): + verify_exception(""" +@external +func f{}(): + return() +end +%builtins pedersen range_check ecdsa +""", """ +file:?:?: Directives must appear at the top of the file. 
+%builtins pedersen range_check ecdsa +^**********************************^ +""") + + +def test_storage_in_builtin_directive(): + verify_exception(""" +%builtins storage +""", """ +file:?:?: 'storage' may not appear in the builtins directive. +%builtins storage +^***************^ +""") + + +def test_lang_directive(): + verify_exception(""" +%lang abc +""", """ +file:?:?: Unsupported %lang directive. Are you using the correct compiler? +%lang abc +^*******^ +""") + + +def test_bad_implicit_arg_name(): + verify_exception(""" +%builtins pedersen range_check ecdsa +@external +func f{hello}(): + return() +end +""", """ +file:?:?: Unexpected implicit argument 'hello' in an external function. +func f{hello}(): + ^***^ +""") + + +def test_wrapper_with_implicit_args(): + program = preprocess_str(""" +%builtins pedersen range_check ecdsa + +struct HashBuiltin: +end + +@external +func f{ecdsa_ptr, pedersen_ptr : HashBuiltin*}(a : felt, b : felt): + return () +end +""") + + assert isinstance(program.identifiers.get_by_full_name( + WRAPPER_SCOPE + 'f'), FunctionDefinition) + + assert program.format() == """\ +%builtins pedersen range_check ecdsa + +[ap] = [fp + (-6)]; ap++ +[ap] = [fp + (-5)]; ap++ +ret +[ap] = [[fp + (-4)] + 3]; ap++ +[ap] = [[fp + (-4)] + 1]; ap++ +[ap] = [[fp + (-3)]]; ap++ +[ap] = [[fp + (-3)] + 1]; ap++ +call rel -7 +%{ memory[ap] = segments.add() %} +ap += 1 +[ap] = [[fp + (-4)]]; ap++ +[ap] = [ap + (-3)]; ap++ +[ap] = [[fp + (-4)] + 2]; ap++ +[ap] = [ap + (-6)]; ap++ +[ap] = 0; ap++ +[ap] = [ap + (-6)]; ap++ +ret +""" + + +def test_wrapper_with_return_args(): + program = preprocess_str(""" +%builtins pedersen range_check ecdsa + +struct HashBuiltin: +end + +@external +func f{ecdsa_ptr}(a : felt, b : felt) -> (c : felt, d : felt): + return (c=1, d=2) +end +""") + + assert isinstance(program.identifiers.get_by_full_name( + WRAPPER_SCOPE + 'f'), FunctionDefinition) + + assert program.format() == """\ +%builtins pedersen range_check ecdsa + +[ap] = [fp + 
(-5)]; ap++ +[ap] = 1; ap++ +[ap] = 2; ap++ +ret +[ap] = [[fp + (-4)] + 3]; ap++ +[ap] = [[fp + (-3)]]; ap++ +[ap] = [[fp + (-3)] + 1]; ap++ +call rel -9 +%{ memory[ap] = segments.add() %} +ap += 1 +[[ap + (-1)]] = [ap + (-3)] +[[ap + (-1)] + 1] = [ap + (-2)] +[ap] = [[fp + (-4)]]; ap++ +[ap] = [[fp + (-4)] + 1]; ap++ +[ap] = [[fp + (-4)] + 2]; ap++ +[ap] = [ap + (-7)]; ap++ +[ap] = 2; ap++ +[ap] = [ap + (-6)]; ap++ +ret +""" + + +def test_wrapper_without_implicit_args(): + program = preprocess_str(""" +%builtins ecdsa +@external +func f(): + return () +end +""") + + assert isinstance(program.identifiers.get_by_full_name( + WRAPPER_SCOPE + 'f'), FunctionDefinition) + + assert program.format() == """\ +%builtins ecdsa + +ret +call rel -1 +%{ memory[ap] = segments.add() %} +ap += 1 +[ap] = [[fp + (-4)]]; ap++ +[ap] = [[fp + (-4)] + 1]; ap++ +[ap] = 0; ap++ +[ap] = [ap + (-4)]; ap++ +ret +""" + + +def test_bad_implicit_arg_type(): + verify_exception(""" +%builtins pedersen + +struct HashBuiltin: +end + +@external +func f{pedersen_ptr : HashBuiltin}(): + return () +end +""", """ +file:?:?: While expanding the reference 'pedersen_ptr' in: +func f{pedersen_ptr : HashBuiltin}(): + ^ +file:?:?: Expected a 'felt' or a pointer type. Got: 'test_scope.HashBuiltin'. +func f{pedersen_ptr : HashBuiltin}(): + ^**********^ +""") + + +def test_unsupported_args(): + verify_exception(""" +@external +func fc(arg : felt*): + return () +end +""", """ +file:?:?: Unsupported argument type felt*. +func fc(arg : felt*): + ^***^ +""") + + +def test_invalid_hint(): + verify_exception(""" +@external +func fc(): + %{ __storage.merkle_update() %} + return () +end +""", """ +file:?:?: Hint is not whitelisted. +This may indicate that this library function cannot be used in StarkNet contracts. 
+ %{ __storage.merkle_update() %} + ^*****************************^ +""") + + +def test_abi(): + program = preprocess_str(""" +@external +func f(a: felt) -> (b: felt, c: felt): + return (0, 1) +end + +@view +func g() -> (a: felt): + return (0) +end +""") + + assert program.abi == [ + { + 'inputs': [ + {'name': 'a', 'type': 'felt'} + ], + 'name': 'f', + 'outputs': [ + {'name': 'b', 'type': 'felt'}, + {'name': 'c', 'type': 'felt'} + ], + 'type': 'function', + }, + { + 'inputs': [], + 'name': 'g', + 'outputs': [ + {'name': 'a', 'type': 'felt'}, + ], + 'type': 'function', + 'stateMutability': 'view', + }, + ] diff --git a/src/starkware/starknet/compiler/storage_var.py b/src/starkware/starknet/compiler/storage_var.py new file mode 100644 index 00000000..ed27b991 --- /dev/null +++ b/src/starkware/starknet/compiler/storage_var.py @@ -0,0 +1,218 @@ +import dataclasses +from typing import Optional, Tuple + +from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt +from starkware.cairo.lang.compiler.ast.code_elements import ( + CodeElementEmptyLine, CodeElementFunction) +from starkware.cairo.lang.compiler.ast.formatting_utils import get_max_line_length +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.parser import parse +from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( + IdentifierAwareVisitor) +from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError +from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE +from starkware.starknet.public.abi import get_storage_var_address + +STORAGE_VAR_DECORATOR = 'storage_var' +STORAGE_VAR_ATTR = 'storage_var' + + +def generate_storage_var_functions( + elm: CodeElementFunction, addr_func_body: str, + read_func_body: str, write_func_body: str) -> CodeElementFunction: + var_name = elm.identifier.name + autogen_filename = f'autogen/starknet/storage_var/{var_name}' + + code = f"""\ +namespace 
{var_name}: + from starkware.starknet.core.storage.storage import Storage, storage_read, storage_write + from starkware.cairo.common.cairo_builtins import HashBuiltin + from starkware.cairo.common.hash import hash2 + + func addr{{pedersen_ptr : HashBuiltin*}}() -> (res : felt): + {addr_func_body} + end + + func read{{storage_ptr : Storage*, pedersen_ptr : HashBuiltin*}}(): + {read_func_body} + end + + func write{{storage_ptr : Storage*, pedersen_ptr : HashBuiltin*}}(value : felt): + {write_func_body} + end +end\ +""" + + res = parse(autogen_filename, code, 'code_element', CodeElementFunction) + + # Copy the arguments and return values. + assert isinstance(res, CodeElementFunction) and res.element_type == 'namespace' + addr_func = res.code_block.code_elements[4].code_elm + assert isinstance(addr_func, CodeElementFunction) and addr_func.element_type == 'func' and \ + addr_func.identifier.name == 'addr' + addr_func.arguments = elm.arguments + + read_func = res.code_block.code_elements[6].code_elm + assert isinstance(read_func, CodeElementFunction) and read_func.element_type == 'func' and \ + read_func.identifier.name == 'read' + read_func.arguments = elm.arguments + read_func.returns = elm.returns + + write_func = res.code_block.code_elements[8].code_elm + assert isinstance(write_func, CodeElementFunction) and write_func.element_type == 'func' and \ + write_func.identifier.name == 'write' + # Append the value argument to the storage var arguments. + write_func.arguments = dataclasses.replace( + elm.arguments, + identifiers=elm.arguments.identifiers + write_func.arguments.identifiers) + + # Format and re-parse to get locations to a well-formatted code. 
+ res = parse( + autogen_filename, res.format(get_max_line_length()), 'code_element', CodeElementFunction) + + res.additional_attributes[STORAGE_VAR_ATTR] = elm + + return res + + +def process_storage_var(elm: CodeElementFunction): + for commented_code_elm in elm.code_block.code_elements: + code_elm = commented_code_elm.code_elm + if not isinstance(code_elm, CodeElementEmptyLine): + if hasattr(code_elm, 'location'): + location = code_elm.location # type: ignore + else: + location = elm.identifier.location + raise PreprocessorError( + 'Storage variables must have an empty body.', + location=location) + + if elm.implicit_arguments is not None: + raise PreprocessorError( + 'Storage variables must have no implicit arguments.', + location=elm.implicit_arguments.location) + + for decorator in elm.decorators: + if decorator.name != STORAGE_VAR_DECORATOR: + raise PreprocessorError( + 'Storage variables must have no decorators in addition to ' + f'@{STORAGE_VAR_DECORATOR}.', + location=decorator.location) + + for arg in elm.arguments.identifiers: + arg_type = arg.get_type() + if not isinstance(arg_type, TypeFelt): + raise PreprocessorError( + 'Only felt arguments are supported in storage variables.', + location=arg_type.location) + + returns_felt = elm.returns is not None and len(elm.returns.identifiers) == 1 and \ + isinstance(elm.returns.identifiers[0].expr_type, TypeFelt) + if not returns_felt: + raise PreprocessorError( + 'Storage variables must return a single value of type felt.', + location=elm.returns.location if elm.returns is not None else elm.identifier.location) + + var_name = elm.identifier.name + addr = storage_var_name_to_base_addr(var_name) + addr_func_body = f'let res = {addr}\n' + for arg in elm.arguments.identifiers: + addr_func_body += \ + f'let (res) = hash2{{hash_ptr=pedersen_ptr}}(res, {arg.identifier.name})\n' + addr_func_body += 'return (res=res)\n' + + args = ', '.join(arg.identifier.name for arg in elm.arguments.identifiers) + + read_func_body = 
f"""\ +let (storage_addr) = addr({args}) +storage_read(address=storage_addr) +return ([ap - 1]) +""" + write_func_body = f"""\ +let (storage_addr) = addr({args}) +storage_write(address=storage_addr, value=value) +return () +""" + return generate_storage_var_functions( + elm, addr_func_body=addr_func_body, read_func_body=read_func_body, + write_func_body=write_func_body) + + +def storage_var_name_to_base_addr(var_name: str) -> int: + """ + Returns the base address of a StarkNet Storage variable, ignoring the storage var arguments. + """ + + return get_storage_var_address(var_name=var_name) + + +def is_storage_var(elm: CodeElementFunction) -> Tuple[bool, Optional[Location]]: + """ + Returns whether the given function has the storage var decorator. If it does, the location of + the decorator is returned. + """ + for decorator in elm.decorators: + if decorator.name == STORAGE_VAR_DECORATOR: + return True, decorator.location + return False, None + + +class StorageVarDeclVisitor(IdentifierAwareVisitor): + """ + Replaces @storage_var decorated functions with a namespace with empty functions. + After the struct collection phase is completed, those functions will be replaced by + functions with full implementation. + """ + + def _visit_default(self, obj): + return obj + + def visit_CodeElementFunction(self, elm: CodeElementFunction): + storage_var, storage_var_location = is_storage_var(elm) + if storage_var: + if self.file_lang != STARKNET_LANG_DIRECTIVE: + raise PreprocessorError( + '@storage_var can only be used in source files that contain the ' + '"%lang starknet" directive.', + location=storage_var_location) + # Add dummy references and calls that will be visited by the identifier collector + # and the dependency graph. + # Those statements will later be replaced by the real implementation.
+ addr_func_body = """ +let res = 0 +call hash2 +""" + read_func_body = """ +let storage_addr = 0 +call addr +call storage_read +""" + write_func_body = """ +let storage_addr = 0 +call addr +call storage_write +""" + return generate_storage_var_functions( + elm, addr_func_body=addr_func_body, read_func_body=read_func_body, + write_func_body=write_func_body) + + return elm + + +class StorageVarImplentationVisitor(IdentifierAwareVisitor): + """ + Replaces @storage_var decorated functions (obtained from the additional attribute + STORAGE_VAR_ATTR added by StorageVarDeclVisitor) with a namespace with read() and write() + functions. + """ + + def _visit_default(self, obj): + return obj + + def visit_CodeElementFunction(self, elm: CodeElementFunction): + attr = elm.additional_attributes.get(STORAGE_VAR_ATTR) + if attr is None: + return elm + + assert isinstance(attr, CodeElementFunction) + return process_storage_var(attr) diff --git a/src/starkware/starknet/compiler/storage_var_test.py b/src/starkware/starknet/compiler/storage_var_test.py new file mode 100644 index 00000000..f6649e91 --- /dev/null +++ b/src/starkware/starknet/compiler/storage_var_test.py @@ -0,0 +1,194 @@ +from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( + strip_comments_and_linebreaks) +from starkware.starknet.compiler.test_utils import preprocess_str, verify_exception +from starkware.starknet.public.abi import starknet_keccak + + +def test_storage_var_success(): + program = preprocess_str(""" +%lang starknet +from starkware.starknet.core.storage.storage import Storage +from starkware.cairo.common.cairo_builtins import HashBuiltin + +func g{storage_ptr : Storage*, pedersen_ptr : HashBuiltin*}(): + alloc_locals + let (x) = my_var.read() + my_var.write(value=x + 1) + local storage_ptr : Storage* = storage_ptr + let (my_var2_addr) = my_var2.addr(1, 2) + my_var2.write(1, 2, 3) + return () +end + +@storage_var +func my_var() -> (res : felt): + # Comment. 
+ +end + +@storage_var +func my_var2(x, y) -> (res : felt): +end +""") + addr = starknet_keccak(b'my_var') + addr2 = starknet_keccak(b'my_var2') + expected_result = f"""\ +# Code for the dummy modules. +ret +ret +ret + +# Implementation of g. +ap += 1 +[ap] = [fp + (-4)]; ap++ # Push storage_ptr. +[ap] = [fp + (-3)]; ap++ # Push pedersen_ptr. +call rel 30 # Call my_var.read. +[ap] = [ap + (-3)]; ap++ # Push (updated) storage_ptr. +[ap] = [ap + (-3)]; ap++ # Push (updated) pedersen_ptr. +[ap] = [ap + (-3)] + 1; ap++ # Push value. +call rel 35 # Call my_var.write. +[fp] = [ap + (-2)] # Copy storage_ptr to a local variable. +[ap] = 1; ap++ # Push 1. +[ap] = 2; ap++ # Push 2. +call rel 38 # Call my_var2.addr. +[ap] = [fp]; ap++ # Push storage_ptr. +[ap] = [ap + (-3)]; ap++ # Push pedersen_ptr. +[ap] = 1; ap++ # Push 1. +[ap] = 2; ap++ # Push 2. +[ap] = 3; ap++ # Push 2. +call rel 38 # Call my_var2.write. +ret + +# Implementation of my_var.addr. +[ap] = [fp + (-3)]; ap++ # Return pedersen_ptr. +[ap] = {addr}; ap++ # Return address. +ret + +# Implementation of my_var.read. +[ap] = [fp + (-3)]; ap++ # Pass pedersen_ptr. +call rel -5 # Call my_var.addr(). +[ap] = [fp + (-4)]; ap++ # Pass storage_ptr. +[ap] = [ap + (-2)]; ap++ # Pass address. +call rel -41 # Call storage_read(). +[ap] = [ap + (-2)]; ap++ # Return storage_ptr. +[ap] = [ap + (-7)]; ap++ # Return (updated) pedersen_ptr. +[ap] = [ap + (-3)]; ap++ # Return value. +ret + +# Implementation of my_var.write. +[ap] = [fp + (-4)]; ap++ # Pass pedersen_ptr. +call rel -16 # Call my_var.addr(). +[ap] = [fp + (-5)]; ap++ # Pass storage_ptr. +[ap] = [ap + (-2)]; ap++ # Pass address. +[ap] = [fp + (-3)]; ap++ # Pass value. +call rel -52 # Call storage_write(). +[ap] = [ap + (-7)]; ap++ # Return (updated) pedersen_ptr. +ret + +# Implementation of my_var2.addr. +[ap] = [fp + (-5)]; ap++ # Push pedersen_ptr. +[ap] = {addr2}; ap++ # Push address. +[ap] = [fp + (-4)]; ap++ # Push x. +call rel -62 # Call hash2(res, x). 
+[ap] = [fp + (-3)]; ap++ # Push y. +call rel -65 # Call hash2(res, y). +ret + +# Implementation of my_var2.write. +[ap] = [fp + (-6)]; ap++ # Pass pedersen_ptr. +[ap] = [fp + (-5)]; ap++ # Pass x. +[ap] = [fp + (-4)]; ap++ # Pass y. +call rel -13 # Call my_var.addr(). +[ap] = [fp + (-7)]; ap++ # Pass storage_ptr. +[ap] = [ap + (-2)]; ap++ # Pass address. +[ap] = [fp + (-3)]; ap++ # Pass value. +call rel -74 # Call storage_write(). +[ap] = [ap + (-7)]; ap++ # Return (updated) pedersen_ptr. +ret +""" + assert program.format() == strip_comments_and_linebreaks(expected_result).lstrip() + + +def test_storage_var_failures(): + verify_exception(""" +@storage_var +func f() -> (res : felt): +end +""", """ +file:?:?: @storage_var can only be used in source files that contain the "%lang starknet" directive. +@storage_var +^**********^ +""") + verify_exception(""" +%lang starknet +@storage_var +func f(): + return () # Comment. +end +""", """ +file:?:?: Storage variables must have an empty body. + return () # Comment. + ^*******^ +""") + verify_exception(""" +%lang starknet +@storage_var +func f(): + 0 = 1 # Comment. +end +""", """ +file:?:?: Storage variables must have an empty body. +func f(): + ^ +""") + verify_exception(""" +%lang starknet +@storage_var +func f{x, y}(): +end +""", """ +file:?:?: Storage variables must have no implicit arguments. +func f{x, y}(): + ^**^ +""") + verify_exception(""" +%lang starknet +@storage_var +@invalid_decorator +func f(): +end +""", """ +file:?:?: Storage variables must have no decorators in addition to @storage_var. +@invalid_decorator +^****************^ +""") + verify_exception(""" +%lang starknet +@storage_var +func f(x, y : felt*): +end +""", """ +file:?:?: Only felt arguments are supported in storage variables. +func f(x, y : felt*): + ^***^ +""") + verify_exception(""" +%lang starknet +@storage_var +func f(): +end +""", """ +file:?:?: Storage variables must return a single value of type felt. 
+func f(): + ^ +""") + verify_exception(""" +%lang starknet +@storage_var +func f() -> (x: felt, y: felt): +end +""", """ +file:?:?: Storage variables must return a single value of type felt. +func f() -> (x: felt, y: felt): + ^**************^ +""") diff --git a/src/starkware/starknet/compiler/test_utils.py b/src/starkware/starknet/compiler/test_utils.py new file mode 100644 index 00000000..4823088e --- /dev/null +++ b/src/starkware/starknet/compiler/test_utils.py @@ -0,0 +1,51 @@ +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError +from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import preprocess_str_ex +from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( + verify_exception as generic_verify_exception) +from starkware.cairo.lang.compiler.test_utils import read_file_from_dict +from starkware.starknet.compiler.starknet_pass_manager import starknet_pass_manager +from starkware.starknet.compiler.starknet_preprocessor import StarknetPreprocessedProgram + +TEST_MODULES = { + 'starkware.starknet.core.storage.storage': """ +struct Storage: +end + +func storage_read{storage_ptr : Storage*}(address : felt) -> (value : felt): + ret +end + +func storage_write{storage_ptr : Storage*}(address : felt, value : felt): + ret +end +""", + 'starkware.cairo.common.cairo_builtins': """ +struct HashBuiltin: +end +""", + 'starkware.cairo.common.hash': """ +from starkware.cairo.common.cairo_builtins import HashBuiltin + +func hash2{hash_ptr : HashBuiltin*}(x, y) -> (result): + ret +end +"""} + + +def preprocess_str(code: str) -> StarknetPreprocessedProgram: + preprocessed = preprocess_str_ex( + code=code, + pass_manager=starknet_pass_manager( + prime=DEFAULT_PRIME, read_module=read_file_from_dict(TEST_MODULES))) + assert isinstance(preprocessed, StarknetPreprocessedProgram) + return preprocessed + + +def verify_exception(code: str, 
error: str, exc_type=PreprocessorError): + return generic_verify_exception( + code=code, + error=error, + pass_manager=starknet_pass_manager( + prime=DEFAULT_PRIME, read_module=read_file_from_dict(TEST_MODULES)), + exc_type=exc_type) diff --git a/src/starkware/starknet/core/CMakeLists.txt b/src/starkware/starknet/core/CMakeLists.txt new file mode 100644 index 00000000..a29c6bf3 --- /dev/null +++ b/src/starkware/starknet/core/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(storage) diff --git a/src/starkware/starknet/core/storage/CMakeLists.txt b/src/starkware/starknet/core/storage/CMakeLists.txt new file mode 100644 index 00000000..4f356fc6 --- /dev/null +++ b/src/starkware/starknet/core/storage/CMakeLists.txt @@ -0,0 +1,20 @@ +python_lib(starknet_cairo_storage_lib + PREFIX starkware/starknet/core/storage + + FILES + storage.cairo +) + +full_python_test(starknet_cairo_storage_lib_test + PREFIX starkware/starknet/core/storage + PYTHON python3.7 + TESTED_MODULES starkware/starknet/core/storage + + FILES + storage_test.py + + LIBS + cairo_function_runner_lib + starknet_cairo_storage_lib + pip_pytest +) diff --git a/src/starkware/starknet/core/storage/__init__.py b/src/starkware/starknet/core/storage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/starknet/core/storage/storage.cairo b/src/starkware/starknet/core/storage/storage.cairo new file mode 100644 index 00000000..58384041 --- /dev/null +++ b/src/starkware/starknet/core/storage/storage.cairo @@ -0,0 +1,36 @@ +from starkware.cairo.common.dict_access import DictAccess + +struct Storage: +end + +# Reads a value from a given address in the storage. +func storage_read{storage_ptr : Storage*}(address : felt) -> (value : felt): + let dict_ptr = cast(storage_ptr, DictAccess*) + + dict_ptr.key = address + + # Put storage in the right place for return value optimization. 
+ tempvar storage_ptr = storage_ptr + DictAccess.SIZE + %{ ids.dict_ptr.prev_value = __storage.read(address=ids.dict_ptr.key) %} + # Make sure prev_value == new_value. + tempvar value = dict_ptr.prev_value + dict_ptr.new_value = value + + return (value=value) +end + +# Writes the given value to the given address in the storage. +func storage_write{storage_ptr : Storage*}(address : felt, value : felt): + let dict_ptr = cast(storage_ptr, DictAccess*) + + # Note that soundness-wise it is ok to set prev_value in the hint. + dict_ptr.key = address + dict_ptr.new_value = value + %{ + ids.dict_ptr.prev_value = __storage.read(address=ids.dict_ptr.key) + __storage.write(address=ids.dict_ptr.key, value=ids.dict_ptr.new_value) + %} + + let storage_ptr = storage_ptr + DictAccess.SIZE + return () +end diff --git a/src/starkware/starknet/core/storage/storage_test.py b/src/starkware/starknet/core/storage/storage_test.py new file mode 100644 index 00000000..8dfe7bdd --- /dev/null +++ b/src/starkware/starknet/core/storage/storage_test.py @@ -0,0 +1,72 @@ +import os +from unittest.mock import MagicMock + +import pytest + +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.cairo.common.structs import CairoStructFactory, CairoStructProxy +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.cairo.lang.compiler.program import Program + +CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'storage.cairo') + + +@pytest.fixture +def program() -> Program: + return compile_cairo_files([CAIRO_FILE], prime=DEFAULT_PRIME) + + +@pytest.fixture +def structs(program: Program) -> CairoStructProxy: + return CairoStructFactory.from_program(program).structs + + +@pytest.fixture +def runner(program: Program) -> CairoFunctionRunner: + return CairoFunctionRunner(program) + + +def test_storage_read(runner: CairoFunctionRunner, structs: CairoStructProxy): + 
stark_net_storage = MagicMock(name='storage') + + storage_value = 45 + stark_net_storage.read.return_value = storage_value + + storage_ptr = runner.segments.add() + address = 17 + + runner.run( + 'storage_read', storage_ptr=storage_ptr, address=address, + hint_locals={'__storage': stark_net_storage}) + + storage_end, value = runner.get_return_values(2) + assert value == storage_value + assert runner.vm_memory.get_range( + storage_ptr, storage_end - storage_ptr) == list(structs.DictAccess( + key=address, prev_value=value, new_value=value)) + + stark_net_storage.read.assert_called_once_with(address=address) + + +def test_storage_write(runner: CairoFunctionRunner, structs: CairoStructProxy): + stark_net_storage = MagicMock(name='storage') + + orig_value = 45 + new_value = 42 + stark_net_storage.read.return_value = orig_value + + storage_ptr = runner.segments.add() + address = 17 + + runner.run( + 'storage_write', storage_ptr=storage_ptr, address=address, value=new_value, + hint_locals={'__storage': stark_net_storage}) + + storage_end, = runner.get_return_values(1) + assert runner.vm_memory.get_range( + storage_ptr, storage_end - storage_ptr) == list(structs.DictAccess( + key=address, prev_value=orig_value, new_value=new_value)) + + stark_net_storage.read.assert_called_once_with(address=address) + stark_net_storage.write.assert_called_once_with(address=address, value=new_value) diff --git a/src/starkware/starknet/definitions/CMakeLists.txt b/src/starkware/starknet/definitions/CMakeLists.txt new file mode 100644 index 00000000..0e027e26 --- /dev/null +++ b/src/starkware/starknet/definitions/CMakeLists.txt @@ -0,0 +1,15 @@ +python_lib(starknet_definitions_lib + PREFIX starkware/starknet/definitions + + FILES + constants.py + error_codes.py + fields.py + transaction_type.py + + LIBS + cairo_vm_crypto_lib + starkware_crypto_lib + starkware_utils_lib + pip_marshmallow +) diff --git a/src/starkware/starknet/definitions/__init__.py 
b/src/starkware/starknet/definitions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/starknet/definitions/constants.py b/src/starkware/starknet/definitions/constants.py new file mode 100644 index 00000000..b52b261d --- /dev/null +++ b/src/starkware/starknet/definitions/constants.py @@ -0,0 +1,16 @@ +from starkware.crypto.signature.signature import FIELD_PRIME + +STARKNET_LANG_DIRECTIVE = 'starknet' + +FIELD_SIZE = FIELD_PRIME + +CALL_DATA_ELEMENT_LOWER_BOUND = 0 +CALL_DATA_ELEMENT_UPPER_BOUND = FIELD_SIZE +CONTRACT_ADDRESS_BITS = 251 +CONTRACT_ADDRESS_LOWER_BOUND = 0 +CONTRACT_ADDRESS_UPPER_BOUND = 2**CONTRACT_ADDRESS_BITS +CONTRACT_STATES_MERKLE_TREE_HEIGHT = 251 +ENTRY_POINT_OFFSET_LOWER_BOUND = 0 +ENTRY_POINT_OFFSET_UPPER_BOUND = FIELD_SIZE +ENTRY_POINT_SELECTOR_LOWER_BOUND = 0 +ENTRY_POINT_SELECTOR_UPPER_BOUND = FIELD_SIZE diff --git a/src/starkware/starknet/definitions/error_codes.py b/src/starkware/starknet/definitions/error_codes.py new file mode 100644 index 00000000..616f5465 --- /dev/null +++ b/src/starkware/starknet/definitions/error_codes.py @@ -0,0 +1,23 @@ +from enum import auto + +from starkware.starkware_utils.error_handling import ErrorCode + + +class StarknetErrorCode(ErrorCode): + BLOCK_NOT_FOUND = 0 + CONTRACT_ADDRESS_UNAVAILABLE = auto() + ENTRY_POINT_NOT_FOUND_IN_CONTRACT = auto() + INVALID_CONTRACT_DEFINITION = auto() + INVALID_RETURN_DATA = auto() + INVALID_TRANSACTION_ID = auto() + MULTIPLE_ENTRY_POINTS_MATCH_SELECTOR = auto() + OUT_OF_RANGE_BLOCK_ID = auto() + OUT_OF_RANGE_CALL_DATA_ELEMENT = auto() + OUT_OF_RANGE_CONTRACT_ADDRESS = auto() + OUT_OF_RANGE_CONTRACT_STORAGE_KEY = auto() + OUT_OF_RANGE_ENTRY_POINT_OFFSET = auto() + OUT_OF_RANGE_ENTRY_POINT_SELECTOR = auto() + SECURITY_ERROR = auto() + TRANSACTION_FAILED = auto() + UNEXPECTED_FAILURE = auto() + UNINITIALIZED_CONTRACT = auto() diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py new file 
mode 100644 index 00000000..bd851701 --- /dev/null +++ b/src/starkware/starknet/definitions/fields.py @@ -0,0 +1,72 @@ +import marshmallow.fields as mfields + +from starkware.starknet.definitions import constants +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starkware_utils.field_validators import validate_non_negative, validate_positive +from starkware.starkware_utils.marshmallow_dataclass_fields import BytesAsHex, IntAsStr +from starkware.starkware_utils.validated_fields import ( + RangeValidatedField, int_as_hex_metadata, sequential_id_metadata) + +# Fields data: validation data, dataclass metadata. + +block_id_metadata = sequential_id_metadata(field_name='block_id') + +previous_block_id_metadata = sequential_id_metadata( + field_name='previous_block_id', allow_previous_id=True) + +sequence_number_metadata = sequential_id_metadata(field_name='sequence_number') + +CallDataElementField = RangeValidatedField( + lower_bound=constants.CALL_DATA_ELEMENT_LOWER_BOUND, + upper_bound=constants.CALL_DATA_ELEMENT_UPPER_BOUND, + name_in_error_message='Call data element', + out_of_range_error_code=StarknetErrorCode.OUT_OF_RANGE_CALL_DATA_ELEMENT) + +call_data_metadata = dict( + marshmallow_field=mfields.List(IntAsStr(validate=CallDataElementField.validate))) + +ContractAddressField = RangeValidatedField( + lower_bound=constants.CONTRACT_ADDRESS_LOWER_BOUND, + upper_bound=constants.CONTRACT_ADDRESS_UPPER_BOUND, + name_in_error_message='Contract address', + out_of_range_error_code=StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS) + +contract_address_metadata = int_as_hex_metadata(validated_field=ContractAddressField) + +contract_definitions_metadata = dict(marshmallow_field=mfields.Dict(keys=BytesAsHex)) + +contract_hash_metadata = dict(marshmallow_field=BytesAsHex(required=True)) + +contract_storage_merkle_height_metadata = dict( + marshmallow_field=mfields.Integer( + strict=True, required=True,
validate=validate_positive('contract_storage_merkle_height'))) + +EntryPointSelectorField = RangeValidatedField( + lower_bound=constants.ENTRY_POINT_SELECTOR_LOWER_BOUND, + upper_bound=constants.ENTRY_POINT_SELECTOR_UPPER_BOUND, + name_in_error_message='Entry point selector', + out_of_range_error_code=StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_SELECTOR) + +entry_point_selector_metadata = int_as_hex_metadata(validated_field=EntryPointSelectorField) + +EntryPointOffsetField = RangeValidatedField( + lower_bound=constants.ENTRY_POINT_OFFSET_LOWER_BOUND, + upper_bound=constants.ENTRY_POINT_OFFSET_UPPER_BOUND, + name_in_error_message='Entry point offset', + out_of_range_error_code=StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_OFFSET) + +entry_point_offset_metadata = int_as_hex_metadata(validated_field=EntryPointOffsetField) + +global_state_merkle_height_metadata = dict( + marshmallow_field=mfields.Integer( + strict=True, required=True, validate=validate_non_negative('global_state_merkle_height'))) + +state_root_metadata = dict(marshmallow_field=BytesAsHex(required=True)) + +timestamp_metadata = dict( + marshmallow_field=mfields.Integer( + strict=True, required=True, validate=validate_non_negative('timestamp'))) + +invoke_tx_n_steps_metadata = dict( + marshmallow_field=mfields.Integer( + strict=True, required=True, validate=validate_non_negative('invoke_tx_n_steps'))) diff --git a/src/starkware/starknet/definitions/transaction_type.py b/src/starkware/starknet/definitions/transaction_type.py new file mode 100644 index 00000000..593e0d77 --- /dev/null +++ b/src/starkware/starknet/definitions/transaction_type.py @@ -0,0 +1,6 @@ +from enum import Enum, auto + + +class TransactionType(Enum): + DEPLOY = 0 + INVOKE_FUNCTION = auto() diff --git a/src/starkware/starknet/public/CMakeLists.txt b/src/starkware/starknet/public/CMakeLists.txt new file mode 100644 index 00000000..2260a4cd --- /dev/null +++ b/src/starkware/starknet/public/CMakeLists.txt @@ -0,0 +1,24 @@ 
+python_lib(starknet_abi_lib + PREFIX starkware/starknet/public + FILES + abi.py + + LIBS + cairo_vm_crypto_lib + pip_eth_hash + pip_pycryptodome +) + + +full_python_test(starknet_abi_lib_test + PREFIX starkware/starknet/public + PYTHON python3.7 + TESTED_MODULES starkware/starknet/public/ + + FILES + abi_test.py + + LIBS + starknet_abi_lib + pip_pytest +) diff --git a/src/starkware/starknet/public/__init__.py b/src/starkware/starknet/public/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/starknet/public/abi.py b/src/starkware/starknet/public/abi.py new file mode 100644 index 00000000..2174a9c9 --- /dev/null +++ b/src/starkware/starknet/public/abi.py @@ -0,0 +1,26 @@ +from eth_hash.auto import keccak + +from starkware.cairo.lang.vm.crypto import pedersen_hash + +MASK_250 = 2 ** 250 - 1 + + +def starknet_keccak(data: bytes) -> int: + """ + A variant of eth-keccak that computes a value that fits in a StarkNet field element. + """ + + return int.from_bytes(keccak(data), 'big') & MASK_250 + + +def get_storage_var_address(var_name: str, *args) -> int: + """ + Returns the storage address of a StarkNet storage variable given its name and arguments. + """ + res = starknet_keccak(var_name.encode('utf8')) + + for arg in args: + assert isinstance(arg, int), f'Expected arguments to be integers. Found: {arg}.' 
+ res = pedersen_hash(res, arg) + + return res diff --git a/src/starkware/starknet/public/abi_test.py b/src/starkware/starknet/public/abi_test.py new file mode 100644 index 00000000..a27a6b72 --- /dev/null +++ b/src/starkware/starknet/public/abi_test.py @@ -0,0 +1,7 @@ +from starkware.starknet.public.abi import starknet_keccak + + +def test_starknet_keccak(): + value = starknet_keccak(b'hello') + assert value == 0x8aff950685c2ed4bc3174f3472287b56d9517b9c948127319a09a7a36deac8 + assert value < 2**250 diff --git a/src/starkware/starknet/scripts/CMakeLists.txt b/src/starkware/starknet/scripts/CMakeLists.txt new file mode 100644 index 00000000..c66c7d70 --- /dev/null +++ b/src/starkware/starknet/scripts/CMakeLists.txt @@ -0,0 +1,10 @@ +python_lib(starknet_script_lib + PREFIX starkware/starknet/scripts + FILES + starknet + starknet-compile + + LIBS + starknet_cli_lib + starknet_compile_lib +) diff --git a/src/starkware/starknet/scripts/starknet b/src/starkware/starknet/scripts/starknet new file mode 100755 index 00000000..a015287a --- /dev/null +++ b/src/starkware/starknet/scripts/starknet @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 + +import os +import sys +import asyncio + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..')) +from starkware.starknet.cli.starknet_cli import main # noqa + +if __name__ == '__main__': + sys.exit(asyncio.run(main())) diff --git a/src/starkware/starknet/scripts/starknet-compile b/src/starkware/starknet/scripts/starknet-compile new file mode 100755 index 00000000..c6998082 --- /dev/null +++ b/src/starkware/starknet/scripts/starknet-compile @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..')) +from starkware.starknet.compiler.compile import main # noqa + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/starkware/starknet/security/CMakeLists.txt b/src/starkware/starknet/security/CMakeLists.txt new file mode 100644 index 
00000000..566bdd41 --- /dev/null +++ b/src/starkware/starknet/security/CMakeLists.txt @@ -0,0 +1,46 @@ +python_lib(starknet_security_lib + PREFIX starkware/starknet/security + FILES + secure_hints.py + + LIBS + cairo_compile_lib + cairo_run_lib + pip_marshmallow + pip_marshmallow_dataclass +) + +full_python_test(starknet_hints_latest_whitelist_test + PREFIX starkware/starknet/security + PYTHON python3.7 + TESTED_MODULES starkware/starknet/security + + FILES + latest_whitelist_test.py + secure_hints_test.py + starknet_common.cairo + + LIBS + cairo_common_lib + cairo_constants_lib + starknet_cairo_storage_lib + starknet_security_lib + pip_pytest +) + +python_exe(starknet_hints_latest_whitelist_fix + VENV starknet_hints_latest_whitelist_test_venv + MODULE starkware.starknet.security.latest_whitelist_test + ARGS "--fix" +) + +python_lib(starknet_hints_whitelist_lib + PREFIX starkware/starknet/security + + FILES + hints_whitelist.py + whitelists/latest.json + + LIBS + starknet_security_lib +) diff --git a/src/starkware/starknet/security/__init__.py b/src/starkware/starknet/security/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/starknet/security/hints_whitelist.py b/src/starkware/starknet/security/hints_whitelist.py new file mode 100644 index 00000000..4797f227 --- /dev/null +++ b/src/starkware/starknet/security/hints_whitelist.py @@ -0,0 +1,8 @@ +import os + +from starkware.starknet.security.secure_hints import HintsWhitelist + + +def get_hints_whitelist() -> HintsWhitelist: + return HintsWhitelist.from_file( + filename=os.path.join(os.path.dirname(__file__), 'whitelists/latest.json')) diff --git a/src/starkware/starknet/security/latest_whitelist_test.py b/src/starkware/starknet/security/latest_whitelist_test.py new file mode 100644 index 00000000..f64e5bbc --- /dev/null +++ b/src/starkware/starknet/security/latest_whitelist_test.py @@ -0,0 +1,41 @@ +import argparse +import os + +from starkware.cairo.lang.cairo_constants import 
DEFAULT_PRIME +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.python.utils import get_source_dir_path +from starkware.starknet.security.secure_hints import HintsWhitelist + +""" +Fix using the starknet_hints_latest_whitelist_fix executable. +""" + + +CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'starknet_common.cairo') + + +def run(fix: bool): + program = compile_cairo_files(files=[CAIRO_FILE], prime=DEFAULT_PRIME) + filename = get_source_dir_path('src/starkware/starknet/security/whitelists/latest.json') + whitelist = HintsWhitelist.from_program(program) + if fix: + data = HintsWhitelist.Schema().dumps(whitelist, indent=4, sort_keys=True) + with open(filename, 'w') as fp: + fp.write(data) + fp.write('\n') + return + + expected_whitelist = HintsWhitelist.from_file(filename) + assert whitelist == expected_whitelist + + +def test_latest_whitelist(): + run(fix=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Checks/fixes the latest StarkNet hint whitelist file.') + parser.add_argument('--fix', action='store_true', help='Fix the latest whitelist file.') + args = parser.parse_args() + run(fix=args.fix) diff --git a/src/starkware/starknet/security/secure_hints.py b/src/starkware/starknet/security/secure_hints.py new file mode 100644 index 00000000..3b152452 --- /dev/null +++ b/src/starkware/starknet/security/secure_hints.py @@ -0,0 +1,128 @@ +from dataclasses import field +from typing import ClassVar, Dict, List, Set, Type + +import marshmallow +import marshmallow.fields as mfields +import marshmallow_dataclass + +from starkware.cairo.lang.compiler.preprocessor.flow import ReferenceManager +from starkware.cairo.lang.compiler.program import CairoHint, Program + + +class SetField(mfields.List): + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + res = super()._serialize(value, attr, obj, **kwargs) + return sorted(res, key=lambda x: (x['name'], 
x['expr'])) + + def _deserialize(self, *args, **kwargs): + return set(super()._deserialize(*args, **kwargs)) + + +class InsecureHintError(Exception): + pass + + +@marshmallow_dataclass.dataclass(frozen=True) +class NamedExpression: + name: str + expr: str + + def __lt__(self, other): + if not isinstance(other, NamedExpression): + return NotImplemented + return (self.name, self.expr) < (other.name, other.expr) + + Schema: ClassVar[marshmallow.Schema] + + +@marshmallow_dataclass.dataclass +class HintsWhitelistEntry: + hint_lines: List[str] + allowed_expressions: Set[NamedExpression] = field( + metadata=dict(marshmallow_field=SetField(mfields.Nested(NamedExpression.Schema)))) + + Schema: ClassVar[Type[marshmallow.Schema]] + + def serialize(self) -> dict: + return HintsWhitelistEntry.Schema().dump(self) + + +class HintsWhitelistDict(mfields.Field): + """ + A field that behaves like a dictionary from hint to a set of allowed expressions, but + serializes as a list where the hint is split to lines. + """ + + def _serialize(self, value, attr, obj, **kwargs): + return [ + HintsWhitelistEntry( + hint_lines.split('\n'), allowed_expressions=allowed_expressions).serialize() + for hint_lines, allowed_expressions in sorted(value.items())] + + def _deserialize(self, value, attr, data, **kwargs) -> Dict[str, Set[NamedExpression]]: + entries = [HintsWhitelistEntry.Schema().load(entry) for entry in value] + return {'\n'.join(entry.hint_lines): entry.allowed_expressions for entry in entries} + + +@marshmallow_dataclass.dataclass +class HintsWhitelist: + """ + Checks the security of hints in a Cairo program against a whitelist. + """ + + # Maps a hint string to the set of allowed expressions in its references. + allowed_reference_expressions_for_hint: Dict[str, Set[NamedExpression]] = field( + metadata=dict(marshmallow_field=HintsWhitelistDict())) + Schema: ClassVar[Type[marshmallow.Schema]] + + # Serialization operations. 
+ @classmethod + def from_file(cls, filename: str) -> 'HintsWhitelist': + with open(filename, 'r') as fp: + return cls.Schema().loads(fp.read()) + + @classmethod + def from_program(cls, program: Program) -> 'HintsWhitelist': + """ + Creates a whitelist from all the hints in an existing program. + """ + whitelist = cls(allowed_reference_expressions_for_hint={}) + for hint in program.hints.values(): + whitelist.add_hint_to_whitelist(hint, program.reference_manager) + return whitelist + + def add_hint_to_whitelist(self, hint: CairoHint, reference_manager: ReferenceManager): + self.allowed_reference_expressions_for_hint.setdefault(hint.code, set()).update( + self._get_hint_reference_expressions(hint, reference_manager)) + + # Reading operations. + def verify_program_hint_secure(self, program: Program): + """ + Determines whether a Cairo program is hint-secure. This happens when all the + hints and their associated reference expressions exist within a given whitelist. + """ + for hint in program.hints.values(): + self.verify_hint_secure( + hint=hint, reference_manager=program.reference_manager) + + def verify_hint_secure(self, hint: CairoHint, reference_manager: ReferenceManager): + allowed_expressions = self.allowed_reference_expressions_for_hint.get(hint.code) + if allowed_expressions is None: + raise InsecureHintError(f'Hint is not whitelisted:\n{hint.code}') + + expressions = self._get_hint_reference_expressions(hint, reference_manager) + invalid_expressions = expressions - allowed_expressions + if invalid_expressions: + raise InsecureHintError( + f'Forbidden expressions in hint "{hint.code}":\n{sorted(invalid_expressions)}') + + def _get_hint_reference_expressions( + self, hint: CairoHint, reference_manager: ReferenceManager) -> \ + Set[NamedExpression]: + ref_exprs: Set[NamedExpression] = set() + for ref_name, ref_id in hint.flow_tracking_data.reference_ids.items(): + ref = reference_manager.get_ref(ref_id) + ref_exprs.add(NamedExpression(name=str(ref_name), 
expr=ref.value.format())) + return ref_exprs diff --git a/src/starkware/starknet/security/secure_hints_test.py b/src/starkware/starknet/security/secure_hints_test.py new file mode 100644 index 00000000..f0cfcf00 --- /dev/null +++ b/src/starkware/starknet/security/secure_hints_test.py @@ -0,0 +1,112 @@ +import re + +import pytest + +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo +from starkware.starknet.security.secure_hints import HintsWhitelist, InsecureHintError + +ALLOWED_CODE = """ +func f(a: felt, b: felt): + %{ + This is a hint. + %} + ap += 5 + ret +end +""" + +GOOD_CODES = [ + """ +func f(b: felt): + %{ + This is a hint. + %} + ap += 5 + ret +end +""", +] + +BAD_CODES = [ + (""" +func f(c: felt, a: felt, b: felt): + %{ + This is a hint. + %} + ap += 5 + ret +end +""", """Forbidden expressions in hint "This is a hint.": +[NamedExpression(name='__main__.f.c', expr='[cast(fp + (-5), felt*)]')]"""), + (""" +func f(a: felt, b: felt): + %{ + This is a bad hint. + %} + ap += 5 + ret +end +""", 'is not whitelisted'), + (""" +func f(b: felt, a: felt): + %{ + This is a hint. 
+ %} + ap += 5 + ret +end +""", + """Forbidden expressions in hint "This is a hint.": +[NamedExpression(name='__main__.f.a', expr='[cast(fp + (-3), felt*)]'), \ +NamedExpression(name='__main__.f.b', expr='[cast(fp + (-4), felt*)]')]""" + ), +] + + +def test_secure_hints_cases(): + template_program = compile_cairo(ALLOWED_CODE, DEFAULT_PRIME) + whitelist = HintsWhitelist.from_program(template_program) + for good_code in GOOD_CODES: + program = compile_cairo(good_code, DEFAULT_PRIME) + whitelist.verify_program_hint_secure(program) + for bad_code, message in BAD_CODES: + program = compile_cairo(bad_code, DEFAULT_PRIME) + with pytest.raises(InsecureHintError, match=re.escape(message)): + whitelist.verify_program_hint_secure(program) + + +def test_secure_hints_serialization(): + template_program = compile_cairo(ALLOWED_CODE, DEFAULT_PRIME) + whitelist = HintsWhitelist.from_program(template_program) + data = HintsWhitelist.Schema().dumps(whitelist) + whitelist = HintsWhitelist.Schema().loads(data) + for good_code in GOOD_CODES: + program = compile_cairo(good_code, DEFAULT_PRIME) + whitelist.verify_program_hint_secure(program) + + +def test_collision(): + """ + Tests multiple hints with the same code but different reference expressions. 
+ """ + code = """ +func f(): + let b = [ap] + %{ + ids.b = 1 + %} + ret +end +func g(): + let b = [ap - 10] + %{ + ids.b = 1 + %} + ret +end +""" + program = compile_cairo(code, DEFAULT_PRIME) + whitelist = HintsWhitelist.from_program(program) + assert len(whitelist.allowed_reference_expressions_for_hint) == 1 + whitelist.verify_program_hint_secure(program) diff --git a/src/starkware/starknet/security/starknet_common.cairo b/src/starkware/starknet/security/starknet_common.cairo new file mode 100644 index 00000000..4ad26753 --- /dev/null +++ b/src/starkware/starknet/security/starknet_common.cairo @@ -0,0 +1,14 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.default_dict import default_dict_finalize, default_dict_new +from starkware.cairo.common.dict import dict_read, dict_squash, dict_update, dict_write +from starkware.cairo.common.find_element import find_element, search_sorted, search_sorted_lower +from starkware.cairo.common.math import ( + abs_value, assert_in_range, assert_le, assert_le_250_bit, assert_le_felt, assert_lt, + assert_lt_felt, assert_nn, assert_nn_le, assert_not_equal, assert_not_zero, sign, + signed_div_rem, split_felt, unsigned_div_rem) +from starkware.cairo.common.math_cmp import ( + is_in_range, is_le, is_le_felt, is_nn, is_nn_le, is_not_zero) +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.signature import verify_ecdsa_signature +from starkware.cairo.common.squash_dict import squash_dict +from starkware.starknet.core.storage.storage import storage_read, storage_write diff --git a/src/starkware/starknet/security/whitelists/latest.json b/src/starkware/starknet/security/whitelists/latest.json new file mode 100644 index 00000000..1215b3c9 --- /dev/null +++ b/src/starkware/starknet/security/whitelists/latest.json @@ -0,0 +1,1394 @@ +{ + "allowed_reference_expressions_for_hint": [ + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), 
starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_end" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_start" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.dict.dict_squash.range_check_ptr" + } + ], + "hint_lines": [ + "# Prepare arguments for dict_new. In particular, the same dictionary values should be copied", + "# to the new (squashed) dictionary.", + "vm_enter_scope({", + " # Make __dict_manager accessible.", + " '__dict_manager': __dict_manager,", + " # Create a copy of the dict, in case it changes in the future.", + " 'initial_dict': dict(__dict_manager.get_dict(ids.dict_accesses_end)),", + "})" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(ap + (-1), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.__temp28" + }, + { + "expr": "[cast(fp + (-3), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_end" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_start" + }, + { + "expr": "[cast(ap + (-2), felt*)]", + "name": "starkware.cairo.common.dict.dict_squash.range_check_ptr" + }, + { + "expr": "[cast(ap + (-1), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.squashed_dict_end" + }, + { + "expr": "[cast(fp, starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.squashed_dict_start" + } + ], + "hint_lines": [ + "# Update the DictTracker's current_ptr to point to the end of the squashed dict.", + "__dict_manager.get_tracker(ids.squashed_dict_start).current_ptr = \\", + " ids.squashed_dict_end.address_" + ] + }, + { + 
"allowed_expressions": [ + { + "expr": "[cast(fp + (-6), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_update.dict_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.dict.dict_update.key" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.dict.dict_update.new_value" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.dict.dict_update.prev_value" + } + ], + "hint_lines": [ + "# Verify dict pointer and prev value.", + "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)", + "current_value = dict_tracker.data[ids.key]", + "assert current_value == ids.prev_value, \\", + " f'Wrong previous value in dict. Got {ids.prev_value}, expected {current_value}.'", + "", + "# Update value.", + "dict_tracker.data[ids.key] = ids.new_value", + "dict_tracker.current_ptr += ids.DictAccess.SIZE" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-6), felt**)]", + "name": "starkware.cairo.common.find_element.search_sorted_lower.array_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.find_element.search_sorted_lower.elm_size" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.find_element.search_sorted_lower.index" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.find_element.search_sorted_lower.key" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.find_element.search_sorted_lower.n_elms" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "starkware.cairo.common.find_element.search_sorted_lower.range_check_ptr" + } + ], + "hint_lines": [ + "array_ptr = ids.array_ptr", + "elm_size = ids.elm_size", + "assert isinstance(elm_size, int) and elm_size > 0, \\", + " f'Invalid value for elm_size. 
Got: {elm_size}.'", + "", + "n_elms = ids.n_elms", + "assert isinstance(n_elms, int) and n_elms >= 0, \\", + " f'Invalid value for n_elms. Got: {n_elms}.'", + "if '__find_element_max_size' in globals():", + " assert n_elms <= __find_element_max_size, \\", + " f'find_element() can only be used with n_elms<={__find_element_max_size}. ' \\", + " f'Got: n_elms={n_elms}.'", + "", + "for i in range(n_elms):", + " if memory[array_ptr + elm_size * i] >= ids.key:", + " ids.index = i", + " break", + "else:", + " ids.index = n_elms" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-6), felt**)]", + "name": "starkware.cairo.common.find_element.find_element.array_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.find_element.find_element.elm_size" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.find_element.find_element.index" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.find_element.find_element.key" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.find_element.find_element.n_elms" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "starkware.cairo.common.find_element.find_element.range_check_ptr" + } + ], + "hint_lines": [ + "array_ptr = ids.array_ptr", + "elm_size = ids.elm_size", + "assert isinstance(elm_size, int) and elm_size > 0, \\", + " f'Invalid value for elm_size. Got: {elm_size}.'", + "key = ids.key", + "", + "if '__find_element_index' in globals():", + " ids.index = __find_element_index", + " found_key = memory[array_ptr + elm_size * __find_element_index]", + " assert found_key == key, \\", + " f'Invalid index found in __find_element_index. 
index: {__find_element_index}, ' \\", + " f'expected key {key}, found key: {found_key}.'", + " # Delete __find_element_index to make sure it's not used for the next calls.", + " del __find_element_index", + "else:", + " n_elms = ids.n_elms", + " assert isinstance(n_elms, int) and n_elms >= 0, \\", + " f'Invalid value for n_elms. Got: {n_elms}.'", + " if '__find_element_max_size' in globals():", + " assert n_elms <= __find_element_max_size, \\", + " f'find_element() can only be used with n_elms<={__find_element_max_size}. ' \\", + " f'Got: n_elms={n_elms}.'", + "", + " for i in range(n_elms):", + " if memory[array_ptr + elm_size * i] == key:", + " ids.index = i", + " break", + " else:", + " raise ValueError(f'Key {key} was not found.')" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_value" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "cast(ap + (-3), starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.last_loop_locals" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.n_used_accesses" + }, + { + "expr": "[cast(fp + (-9), felt*)]", + "name": 
"starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.should_skip_loop" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + "assert ids.n_used_accesses == len(access_indices[key])" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_value" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "cast(ap + (-3), starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.last_loop_locals" + }, + { + "expr": "[cast(fp + (-9), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": 
"starkware.cairo.common.squash_dict.squash_dict_inner.should_skip_loop" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + "assert len(current_access_indices) == 0" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_value" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "cast(ap + (-3), starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.last_loop_locals" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.n_used_accesses" + }, + { + "expr": "cast([ap + (-3)] + 1, felt)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.should_skip_loop" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": 
"starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + "assert len(keys) == 0" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_value" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "cast(ap + (-3), starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.last_loop_locals" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.n_used_accesses" + }, + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.next_key" + }, + { + "expr": "cast([ap + (-3)] + 1, felt)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.should_skip_loop" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + 
"assert len(keys) > 0, 'No keys left but remaining_accesses > 0.'", + "ids.next_key = key = keys.pop()" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "[cast(fp + (-9), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + "current_access_indices = sorted(access_indices[key])[::-1]", + "current_access_index = current_access_indices.pop()", + "memory[ids.range_check_ptr] = current_access_index" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + 2, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.big_keys" + }, + { + "expr": "[cast(fp + (-5), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.dict_accesses" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.dict_accesses_end" + }, + { + "expr": "[cast(fp + 1, felt*)]", + 
"name": "starkware.cairo.common.squash_dict.squash_dict.first_key" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.n_accesses" + }, + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.ptr_diff" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.range_check_ptr" + }, + { + "expr": "[cast(fp + (-3), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.squashed_dict" + } + ], + "hint_lines": [ + "dict_access_size = ids.DictAccess.SIZE", + "address = ids.dict_accesses.address_", + "assert ids.ptr_diff % dict_access_size == 0, \\", + " 'Accesses array size must be divisible by DictAccess.SIZE'", + "n_accesses = ids.n_accesses", + "if '__squash_dict_max_size' in globals():", + " assert n_accesses <= __squash_dict_max_size, \\", + " f'squash_dict() can only be used with n_accesses<={__squash_dict_max_size}. 
' \\", + " f'Got: n_accesses={n_accesses}.'", + "# A map from key to the list of indices accessing it.", + "access_indices = {}", + "for i in range(n_accesses):", + " key = memory[address + dict_access_size * i]", + " access_indices.setdefault(key, []).append(i)", + "# Descending list of keys.", + "keys = sorted(access_indices.keys(), reverse=True)", + "# Are the keys used bigger than range_check bound.", + "ids.big_keys = 1 if keys[0] >= range_check_builtin.bound else 0", + "ids.first_key = key = keys.pop()" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-5), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_write.dict_ptr" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.dict.dict_write.key" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.dict.dict_write.new_value" + } + ], + "hint_lines": [ + "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)", + "dict_tracker.current_ptr += ids.DictAccess.SIZE", + "ids.dict_ptr.prev_value = dict_tracker.data[ids.key]", + "dict_tracker.data[ids.key] = ids.new_value" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_read.dict_ptr" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.dict.dict_read.key" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.dict.dict_read.value" + } + ], + "hint_lines": [ + "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)", + "dict_tracker.current_ptr += ids.DictAccess.SIZE", + "ids.value = dict_tracker.data[ids.key]" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-7), starkware.cairo.common.cairo_builtins.SignatureBuiltin**)]", + "name": "starkware.cairo.common.signature.verify_ecdsa_signature.ecdsa_ptr" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": 
"starkware.cairo.common.signature.verify_ecdsa_signature.message" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.signature.verify_ecdsa_signature.public_key" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.signature.verify_ecdsa_signature.signature_r" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.signature.verify_ecdsa_signature.signature_s" + } + ], + "hint_lines": [ + "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.assert_le_250_bit.a" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.assert_le_250_bit.b" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.math.assert_le_250_bit.diff" + }, + { + "expr": "[cast([fp + (-5)] + 1, felt*)]", + "name": "starkware.cairo.common.math.assert_le_250_bit.high" + }, + { + "expr": "[cast([fp + (-5)], felt*)]", + "name": "starkware.cairo.common.math.assert_le_250_bit.low" + }, + { + "expr": "cast([fp + (-5)] + 2, felt)", + "name": "starkware.cairo.common.math.assert_le_250_bit.range_check_ptr" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import as_int", + "", + "# Soundness checks.", + "assert range_check_builtin.bound == 2**128", + "assert ids.UPPER_BOUND == ids.HIGH_PART_SHIFT * range_check_builtin.bound", + "", + "# Correctness check.", + "diff = as_int(ids.diff, PRIME)", + "values_msg = f'(a={as_int(ids.a, PRIME)}, b={as_int(ids.b, PRIME)}).'", + "assert diff < ids.UPPER_BOUND, f'(b - a)={diff} is outside of the valid range. {values_msg}'", + "assert PRIME - ids.UPPER_BOUND > (ids.HIGH_PART_SHIFT + 1) * range_check_builtin.bound", + "", + "assert diff >= 0, f'(b - a)={diff} < 0. 
{values_msg}'", + "", + "# Calculation for the assertion.", + "ids.high = ids.diff // ids.HIGH_PART_SHIFT", + "ids.low = ids.diff % ids.HIGH_PART_SHIFT" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast([fp + (-6)] + 1, felt*)]", + "name": "starkware.cairo.common.math.signed_div_rem.biased_q" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.signed_div_rem.bound" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.signed_div_rem.div" + }, + { + "expr": "cast([[fp + (-6)] + 1] - [fp + (-3)], felt)", + "name": "starkware.cairo.common.math.signed_div_rem.q" + }, + { + "expr": "[cast([fp + (-6)], felt*)]", + "name": "starkware.cairo.common.math.signed_div_rem.r" + }, + { + "expr": "cast([fp + (-6)] + 2, felt)", + "name": "starkware.cairo.common.math.signed_div_rem.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.math.signed_div_rem.value" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import as_int, assert_integer", + "", + "assert_integer(ids.div)", + "assert 0 < ids.div <= PRIME // range_check_builtin.bound, \\", + " f'div={hex(ids.div)} is out of the valid range.'", + "", + "assert_integer(ids.bound)", + "assert ids.bound <= range_check_builtin.bound // 2, \\", + " f'bound={hex(ids.bound)} is out of the valid range.'", + "", + "int_value = as_int(ids.value, PRIME)", + "q, ids.r = divmod(int_value, ids.div)", + "", + "assert -ids.bound <= q < ids.bound, \\", + " f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).'", + "", + "ids.biased_q = q + ids.bound" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast([fp + (-4)] + 1, felt*)]", + "name": "starkware.cairo.common.math.split_felt.high" + }, + { + "expr": "[cast([fp + (-4)], felt*)]", + "name": "starkware.cairo.common.math.split_felt.low" + }, + { + "expr": "cast([fp + (-4)] + 2, felt)", + "name": 
"starkware.cairo.common.math.split_felt.range_check_ptr" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.split_felt.value" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import assert_integer", + "assert PRIME < 2**256", + "assert_integer(ids.value)", + "ids.low = ids.value & ((1 << 128) - 1)", + "ids.high = ids.value >> 128" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.assert_nn.a" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.assert_nn.range_check_ptr" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import assert_integer", + "assert_integer(ids.a)", + "assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.'" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.assert_lt_felt.a" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.assert_lt_felt.b" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.math.assert_lt_felt.range_check_ptr" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import assert_integer", + "assert_integer(ids.a)", + "assert_integer(ids.b)", + "assert (ids.a % PRIME) < (ids.b % PRIME), \\", + " f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.'" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.assert_le_felt.a" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.assert_le_felt.b" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.math.assert_le_felt.range_check_ptr" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import assert_integer", + "assert_integer(ids.a)", + "assert_integer(ids.b)", + "assert (ids.a % 
PRIME) <= (ids.b % PRIME), \\", + " f'a = {ids.a % PRIME} is not less than or equal to b = {ids.b % PRIME}.'" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.unsigned_div_rem.div" + }, + { + "expr": "[cast([fp + (-5)] + 1, felt*)]", + "name": "starkware.cairo.common.math.unsigned_div_rem.q" + }, + { + "expr": "[cast([fp + (-5)], felt*)]", + "name": "starkware.cairo.common.math.unsigned_div_rem.r" + }, + { + "expr": "cast([fp + (-5)] + 2, felt)", + "name": "starkware.cairo.common.math.unsigned_div_rem.range_check_ptr" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.unsigned_div_rem.value" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import assert_integer", + "assert_integer(ids.div)", + "assert 0 < ids.div <= PRIME // range_check_builtin.bound, \\", + " f'div={hex(ids.div)} is out of the valid range.'", + "ids.q, ids.r = divmod(ids.value, ids.div)" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.assert_not_zero.value" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import assert_integer", + "assert_integer(ids.value)", + "assert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.'" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.abs_value.range_check_ptr" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.abs_value.value" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.sign.range_check_ptr" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.sign.value" + } + ], + "hint_lines": [ + "from starkware.cairo.common.math_utils import is_positive", + "memory[ap] = 1 if is_positive(", + " value=ids.value, prime=PRIME, rc_bound=range_check_builtin.bound) else 
0" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.assert_not_equal.a" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math.assert_not_equal.b" + } + ], + "hint_lines": [ + "from starkware.cairo.lang.vm.relocatable import RelocatableValue", + "both_ints = isinstance(ids.a, int) and isinstance(ids.b, int)", + "both_relocatable = (", + " isinstance(ids.a, RelocatableValue) and isinstance(ids.b, RelocatableValue) and", + " ids.a.segment_index == ids.b.segment_index)", + "assert both_ints or both_relocatable, \\", + " f'assert_not_equal failed: non-comparable values: {ids.a}, {ids.b}.'", + "assert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.'" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.starknet.core.storage.storage.storage_read.address" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.starknet.core.storage.storage.storage_read.dict_ptr" + }, + { + "expr": "[cast(ap + (-1), starkware.starknet.core.storage.storage.Storage**)]", + "name": "starkware.starknet.core.storage.storage.storage_read.storage_ptr" + } + ], + "hint_lines": [ + "ids.dict_ptr.prev_value = __storage.read(address=ids.dict_ptr.key)" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.starknet.core.storage.storage.storage_write.address" + }, + { + "expr": "[cast(fp + (-5), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.starknet.core.storage.storage.storage_write.dict_ptr" + }, + { + "expr": "cast([fp + (-5)] + 3, starkware.starknet.core.storage.storage.Storage*)", + "name": "starkware.starknet.core.storage.storage.storage_write.storage_ptr" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.starknet.core.storage.storage.storage_write.value" + } + ], + 
"hint_lines": [ + "ids.dict_ptr.prev_value = __storage.read(address=ids.dict_ptr.key)", + "__storage.write(address=ids.dict_ptr.key, value=ids.dict_ptr.new_value)" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(ap + 1, starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.access" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.current_access_index" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(ap, starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_access" + }, + { + "expr": "cast(ap, starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_loop_locals" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_value" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "cast(ap + 4, starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.loop_locals" + }, + { + "expr": "cast(ap, starkware.cairo.common.squash_dict.squash_dict_inner.LoopTemps*)", + "name": 
"starkware.cairo.common.squash_dict.squash_dict_inner.loop_temps" + }, + { + "expr": "cast(ap + (-3), starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.prev_loop_locals" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.ptr_delta" + }, + { + "expr": "[cast(fp + (-9), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.should_skip_loop" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + "ids.loop_temps.should_continue = 1 if current_access_indices else 0" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.current_access_index" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(ap, starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_access" + }, + { + "expr": "cast(ap, 
starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_loop_locals" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_value" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.ptr_delta" + }, + { + "expr": "[cast(fp + (-9), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.should_skip_loop" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + "ids.should_skip_loop = 0 if current_access_indices else 1" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.default_dict.default_dict_new.default_value" + } + ], + "hint_lines": [ + "if '__dict_manager' not in globals():", + " from starkware.cairo.common.dict import DictManager", + " __dict_manager = DictManager()", + "", + "memory[ap] = __dict_manager.new_default_dict(segments, ids.default_value)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "if '__dict_manager' not in globals():", + " from starkware.cairo.common.dict import DictManager", + " __dict_manager = DictManager()", + "", + "memory[ap] = __dict_manager.new_dict(segments, initial_dict)", + "del initial_dict" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math_cmp.is_le_felt.a" + }, + { + 
"expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math_cmp.is_le_felt.b" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.math_cmp.is_le_felt.range_check_ptr" + } + ], + "hint_lines": [ + "memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math_cmp.is_nn.a" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math_cmp.is_nn.range_check_ptr" + } + ], + "hint_lines": [ + "memory[ap] = 0 if 0 <= ((-ids.a - 1) % PRIME) < range_check_builtin.bound else 1" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.math_cmp.is_nn.a" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math_cmp.is_nn.range_check_ptr" + } + ], + "hint_lines": [ + "memory[ap] = 0 if 0 <= (ids.a % PRIME) < range_check_builtin.bound else 1" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "memory[ap] = segments.add()" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.memcpy.memcpy.__temp43" + }, + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.memcpy.memcpy.continue_copying" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "starkware.cairo.common.memcpy.memcpy.dst" + }, + { + "expr": "[cast(ap + (-2), starkware.cairo.common.memcpy.memcpy.LoopFrame*)]", + "name": "starkware.cairo.common.memcpy.memcpy.frame" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.memcpy.memcpy.len" + }, + { + "expr": "cast(ap + 1, starkware.cairo.common.memcpy.memcpy.LoopFrame*)", + "name": "starkware.cairo.common.memcpy.memcpy.next_frame" + }, + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "starkware.cairo.common.memcpy.memcpy.src" + } + ], + "hint_lines": [ + "n -= 1", + 
"ids.continue_copying = 1 if n > 0 else 0" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.big_keys" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.current_access_index" + }, + { + "expr": "[cast(fp + (-8), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses" + }, + { + "expr": "[cast(fp + (-7), felt**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_accesses_end_minus1" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.dict_diff" + }, + { + "expr": "[cast(ap, starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_access" + }, + { + "expr": "cast(ap, starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_loop_locals" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.first_value" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.key" + }, + { + "expr": "cast(ap + 4, starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.loop_locals" + }, + { + "expr": "cast(ap, starkware.cairo.common.squash_dict.squash_dict_inner.LoopTemps*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.loop_temps" + }, + { + "expr": "cast(ap + (-3), starkware.cairo.common.squash_dict.squash_dict_inner.LoopLocals*)", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.prev_loop_locals" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": 
"starkware.cairo.common.squash_dict.squash_dict_inner.ptr_delta" + }, + { + "expr": "[cast(fp + (-9), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.remaining_accesses" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.should_skip_loop" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict_inner.squashed_dict" + } + ], + "hint_lines": [ + "new_access_index = current_access_indices.pop()", + "ids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1", + "current_access_index = new_access_index" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-5), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.dict_accesses" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.dict_accesses_end" + }, + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.ptr_diff" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.range_check_ptr" + }, + { + "expr": "[cast(fp + (-3), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.squashed_dict" + } + ], + "hint_lines": [ + "vm_enter_scope()" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "starkware.cairo.common.memcpy.memcpy.dst" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.memcpy.memcpy.len" + }, + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "starkware.cairo.common.memcpy.memcpy.src" + } + ], + "hint_lines": [ + 
"vm_enter_scope({'n': ids.len})" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(ap + (-1), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.__temp28" + }, + { + "expr": "[cast(fp + (-3), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_end" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_start" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.dict.dict_squash.range_check_ptr" + }, + { + "expr": "[cast(fp, starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.dict.dict_squash.squashed_dict_start" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "starkware.cairo.common.memcpy.memcpy.__temp43" + }, + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.memcpy.memcpy.continue_copying" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "starkware.cairo.common.memcpy.memcpy.dst" + }, + { + "expr": "[cast(ap + (-2), starkware.cairo.common.memcpy.memcpy.LoopFrame*)]", + "name": "starkware.cairo.common.memcpy.memcpy.frame" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.memcpy.memcpy.len" + }, + { + "expr": "cast(ap + 1, starkware.cairo.common.memcpy.memcpy.LoopFrame*)", + "name": "starkware.cairo.common.memcpy.memcpy.next_frame" + }, + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "starkware.cairo.common.memcpy.memcpy.src" + }, + { + "expr": "[cast(fp + 2, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.big_keys" + }, + { + "expr": "[cast(fp + (-5), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.dict_accesses" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.dict_access.DictAccess**)]", + "name": 
"starkware.cairo.common.squash_dict.squash_dict.dict_accesses_end" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.first_key" + }, + { + "expr": "[cast(ap - 1 + (-1), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.n_accesses" + }, + { + "expr": "[cast(ap - 5, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.ptr_diff" + }, + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.ptr_diff" + }, + { + "expr": "[cast(ap + (-2), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.range_check_ptr" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.range_check_ptr" + }, + { + "expr": "[cast(ap + (-1), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.squashed_dict" + }, + { + "expr": "[cast(fp + (-3), starkware.cairo.common.dict_access.DictAccess**)]", + "name": "starkware.cairo.common.squash_dict.squash_dict.squashed_dict" + } + ], + "hint_lines": [ + "vm_exit_scope()" + ] + } + ] +} diff --git a/src/starkware/starknet/services/CMakeLists.txt b/src/starkware/starknet/services/CMakeLists.txt new file mode 100644 index 00000000..53d257b0 --- /dev/null +++ b/src/starkware/starknet/services/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(api) diff --git a/src/starkware/starknet/services/api/CMakeLists.txt b/src/starkware/starknet/services/api/CMakeLists.txt new file mode 100644 index 00000000..f6890e55 --- /dev/null +++ b/src/starkware/starknet/services/api/CMakeLists.txt @@ -0,0 +1,16 @@ +add_subdirectory(feeder_gateway) +add_subdirectory(gateway) + +python_lib(starknet_contract_definition_lib + PREFIX starkware/starknet/services/api + + FILES + contract_definition.py + + LIBS + cairo_compile_lib + starknet_definitions_lib + starkware_utils_lib + pip_marshmallow + pip_marshmallow_dataclass +) diff --git 
a/src/starkware/starknet/services/api/__init__.py b/src/starkware/starknet/services/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/starknet/services/api/contract_definition.py b/src/starkware/starknet/services/api/contract_definition.py new file mode 100644 index 00000000..31763f4e --- /dev/null +++ b/src/starkware/starknet/services/api/contract_definition.py @@ -0,0 +1,40 @@ +import dataclasses +from dataclasses import field +from typing import Any, ClassVar, List, Optional, Type + +import marshmallow +import marshmallow_dataclass + +from starkware.cairo.lang.compiler.program import Program +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starkware_utils.error_handling import stark_assert +from starkware.starkware_utils.validated_dataclass import ( + ValidatedDataclass, ValidatedMarshmallowDataclass) + + +@dataclasses.dataclass(frozen=True) +class ContractEntryPoint(ValidatedDataclass): + # A field element that encodes the signature of the called function. + selector: int = field(metadata=fields.entry_point_selector_metadata) + # The offset of the instruction that should be called within the contract bytecode. + offset: int = field(metadata=fields.entry_point_offset_metadata) + + +@marshmallow_dataclass.dataclass(frozen=True) +class ContractDefinition(ValidatedMarshmallowDataclass): + """ + Represents a contract in the StarkNet network. 
+ """ + program: Program + entry_points: List[ContractEntryPoint] + abi: Optional[List[Any]] = None + Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema + + def __post_init__(self): + super().__post_init__() + + stark_assert( + len(self.entry_points) == len(set([ep.selector for ep in self.entry_points])), + code=StarknetErrorCode.MULTIPLE_ENTRY_POINTS_MATCH_SELECTOR, + message='Entry points must be unique.') diff --git a/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt b/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt new file mode 100644 index 00000000..d747cf16 --- /dev/null +++ b/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt @@ -0,0 +1,10 @@ +python_lib(starknet_feeder_gateway_client_lib + PREFIX starkware/starknet/services/api/feeder_gateway + + FILES + feeder_gateway_client.py + + LIBS + everest_feeder_gateway_client_lib + starknet_transaction_lib +) diff --git a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py new file mode 100644 index 00000000..e97fd946 --- /dev/null +++ b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py @@ -0,0 +1,42 @@ +import json +from typing import Any, Dict, List, Optional + +from services.everest.api.feeder_gateway.feeder_gateway_client import EverestFeederGatewayClient +from starkware.starknet.services.api.gateway.transaction import InvokeFunction + + +class FeederGatewayClient(EverestFeederGatewayClient): + """ + A client class for the StarkNet FeederGateway. 
+ """ + + async def call_contract( + self, invoke_tx: InvokeFunction, + block_id: Optional[int] = None) -> Dict[str, List[int]]: + raw_response = await self._send_request( + send_method='POST', + uri=f'/call_contract?blockId={json.dumps(block_id)}', data=invoke_tx.dumps()) + return json.loads(raw_response) + + async def get_block(self, block_id: Optional[int] = None) -> Dict[str, Any]: + raw_response = await self._send_request( + send_method='GET', uri=f'/get_block?blockId={json.dumps(block_id)}') + return json.loads(raw_response) + + async def get_code(self, contract_address: int, block_id: Optional[int] = None) -> List[int]: + uri = f'/get_code?contractAddress={hex(contract_address)}&blockId={json.dumps(block_id)}' + raw_response = await self._send_request(send_method='GET', uri=uri) + return json.loads(raw_response) + + async def get_storage_at( + self, contract_address: int, key: int, block_id: Optional[int] = None) -> int: + uri = ( + f'/get_storage_at?contractAddress={hex(contract_address)}&key={key}&' + f'blockId={json.dumps(block_id)}') + raw_response = await self._send_request(send_method='GET', uri=uri) + return json.loads(raw_response) + + async def get_transaction_status(self, tx_id: int) -> str: + raw_response = await self._send_request( + send_method='GET', uri=f'/get_transaction_status?transactionId={tx_id}') + return json.loads(raw_response) diff --git a/src/starkware/starknet/services/api/gateway/CMakeLists.txt b/src/starkware/starknet/services/api/gateway/CMakeLists.txt new file mode 100644 index 00000000..a4266a8b --- /dev/null +++ b/src/starkware/starknet/services/api/gateway/CMakeLists.txt @@ -0,0 +1,27 @@ +python_lib(starknet_transaction_lib + PREFIX starkware/starknet/services/api/gateway + + FILES + transaction.py + + LIBS + everest_definitions_lib + everest_transaction_lib + pip_marshmallow + starknet_contract_definition_lib + starknet_definitions_lib + pip_marshmallow_dataclass + pip_marshmallow_enum + pip_marshmallow_oneofschema +) + 
+python_lib(starknet_gateway_client_lib + PREFIX starkware/starknet/services/api/gateway + + FILES + gateway_client.py + + LIBS + everest_gateway_client_lib + starknet_transaction_lib +) diff --git a/src/starkware/starknet/services/api/gateway/gateway_client.py b/src/starkware/starknet/services/api/gateway/gateway_client.py new file mode 100644 index 00000000..5767cce3 --- /dev/null +++ b/src/starkware/starknet/services/api/gateway/gateway_client.py @@ -0,0 +1,15 @@ +import json +from typing import Dict + +from services.everest.api.gateway.gateway_client import EverestGatewayClient +from starkware.starknet.services.api.gateway.transaction import Transaction + + +class GatewayClient(EverestGatewayClient): + """ + A client class for the StarkNet Gateway. + """ + async def add_transaction(self, tx: Transaction) -> Dict[str, int]: + raw_response = await self._send_request( + send_method='POST', uri='/add_transaction', data=Transaction.Schema().dumps(obj=tx)) + return json.loads(raw_response) diff --git a/src/starkware/starknet/services/api/gateway/transaction.py b/src/starkware/starknet/services/api/gateway/transaction.py new file mode 100644 index 00000000..bf2e6bef --- /dev/null +++ b/src/starkware/starknet/services/api/gateway/transaction.py @@ -0,0 +1,116 @@ +import base64 +import dataclasses +import gzip +import json +from abc import abstractmethod +from dataclasses import field +from typing import Any, ClassVar, Dict, List, Type + +import marshmallow +import marshmallow.decorators +import marshmallow_dataclass +from marshmallow_oneofschema import OneOfSchema + +from services.everest.api.gateway.transaction import ( + EverestAddTransactionRequest, EverestTransaction) +from services.everest.definitions import fields as everest_fields +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.transaction_type import TransactionType +from starkware.starknet.services.api.contract_definition import ContractDefinition + + +class 
Transaction(EverestTransaction): + """ + StarkNet transaction base class. + """ + + @property + @classmethod + @abstractmethod + def tx_type(cls) -> TransactionType: + """ + Returns the corresponding TransactionType enum. Used in TransactionSchema. + Subclasses should define it as a class variable. + """ + + +@marshmallow_dataclass.dataclass(frozen=True) +class Deploy(Transaction): + """ + Represents a transaction in the StarkNet network that is a deployment of a StarkNet contract. + """ + + contract_address: int = field(metadata=fields.contract_address_metadata) + contract_definition: ContractDefinition + + # Class variables. + tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY + + @marshmallow.decorators.post_dump + def compress_program(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: + full_program = json.dumps(data['contract_definition']['program']) + compressed_program = gzip.compress(data=full_program.encode('ascii')) + compressed_program = base64.b64encode(compressed_program) + data['contract_definition']['program'] = compressed_program.decode('ascii') + return data + + @marshmallow.decorators.pre_load + def decompress_program(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: + compressed_program: str = data['contract_definition']['program'] + compressed_program_bytes = base64.b64decode(compressed_program.encode('ascii')) + decompressed_program = gzip.decompress(data=compressed_program_bytes) + data['contract_definition']['program'] = json.loads(decompressed_program.decode('ascii')) + return data + + def _remove_debug_info(self) -> 'Deploy': + """ + Sets debug_info in the Cairo contract program to None. + Returns an altered Deploy instance.
+ """ + altered_program = dataclasses.replace(self.contract_definition.program, debug_info=None) + altered_contract_definition = dataclasses.replace( + self.contract_definition, program=altered_program) + return dataclasses.replace(self, contract_definition=altered_contract_definition) + + +@marshmallow_dataclass.dataclass(frozen=True) +class InvokeFunction(Transaction): + """ + Represents a transaction in the StarkNet network that is an invocation of a Cairo contract + function. + """ + + contract_address: int = field(metadata=fields.contract_address_metadata) + # A field element that encodes the signature of the called function. + entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) + calldata: List[int] = field(metadata=fields.call_data_metadata) + + # Class variables. + tx_type: ClassVar[TransactionType] = TransactionType.INVOKE_FUNCTION + + +class TransactionSchema(OneOfSchema): + """ + Schema for transaction. + OneOfSchema adds a "type" field. + + Allows the use of load/dump of different transaction type data directly via the + Transaction class (e.g., Transaction.load(invoke_function_dict), where + {"type": "INVOKE_FUNCTION"} is in invoke_function_dict, will produce an InvokeFunction object). 
+ """ + type_schemas: Dict[str, Type[marshmallow.Schema]] = { + TransactionType.DEPLOY.name: Deploy.Schema, + TransactionType.INVOKE_FUNCTION.name: InvokeFunction.Schema, + } + + def get_obj_type(self, obj: Transaction) -> str: + return obj.tx_type.name + + +Transaction.Schema = TransactionSchema + + +@marshmallow_dataclass.dataclass(frozen=True) +class AddTransactionRequest(EverestAddTransactionRequest): + tx: Transaction + tx_id: int = field(metadata=everest_fields.tx_id_field_metadata) diff --git a/src/starkware/starkware_utils/CMakeLists.txt b/src/starkware/starkware_utils/CMakeLists.txt new file mode 100644 index 00000000..7eda5b62 --- /dev/null +++ b/src/starkware/starkware_utils/CMakeLists.txt @@ -0,0 +1,29 @@ +python_lib(starkware_serializability_utils_lib + PREFIX starkware/starkware_utils + + FILES + serializable.py +) + +python_lib(starkware_utils_lib + PREFIX starkware/starkware_utils + + FILES + custom_raising_dict.py + error_handling.py + field_validators.py + marshmallow_dataclass_fields.py + validated_dataclass.py + validated_fields.py + ${STARKWARE_UTILS_LIBS_ADDITIONAL_FILES} + + LIBS + starkware_python_utils_lib + starkware_serializability_utils_lib + pip_frozendict + pip_marshmallow + pip_marshmallow_dataclass + pip_typeguard + pip_web3 + ${STARKWARE_UTILS_LIBS_ADDITIONAL_LIBS} +) diff --git a/src/starkware/starkware_utils/custom_raising_dict.py b/src/starkware/starkware_utils/custom_raising_dict.py new file mode 100644 index 00000000..f3a6259f --- /dev/null +++ b/src/starkware/starkware_utils/custom_raising_dict.py @@ -0,0 +1,71 @@ +from abc import ABC +from collections import UserDict +from typing import Generic, Type, TypeVar + +from frozendict import frozendict + +KT = TypeVar('KT') +VT = TypeVar('VT') + + +class CustomRaisingDict(ABC, UserDict, Generic[KT, VT]): + """ + A dictionary that raises a custom exception. + The exception's type must be a subclass of KeyError. 
+ """ + + @property + @classmethod + def exception_type(cls) -> Type[Exception]: + raise NotImplementedError() + + @classmethod + def __init_subclass__(cls, exception_type: Type[Exception], **kwargs): + super().__init_subclass__(**kwargs) # type: ignore[call-arg] + + assert issubclass(exception_type, KeyError), 'Exception type must subclass KeyError.' + cls.exception_type = exception_type # type: ignore + + def __getitem__(self, key: KT) -> VT: + try: + return super().__getitem__(key) + except KeyError: + raise self.exception_type(key) from None + + def __delitem__(self, key: KT): + try: + super().__delitem__(key) + except KeyError: + raise self.exception_type(key) from None + + +class CustomRaisingFrozenDict(frozendict, Generic[KT, VT]): + """ + A frozen CustomRaisingDict. + For a nonexistent key k in D, D[k] will raise a custom exception; del D[k] will raise a + TypeError (as in frozendict). + """ + + @classmethod + def __init_subclass__(cls, exception_type: Type[Exception], **kwargs): + super().__init_subclass__(**kwargs) # type: ignore[call-arg] + + class _CustomRaisingFrozenDict(CustomRaisingDict[KT, VT], exception_type=exception_type): + pass + + _CustomRaisingFrozenDict.__name__ = _CustomRaisingFrozenDict.__qualname__ = \ + 'CustomRaisingFrozenDict' + + cls.dict_cls = _CustomRaisingFrozenDict + + def __hash__(self): + """ + Calculates the hash of the dictionary, without taking its order into account. + The type is concatenated so that the hash will not equal the hash of the tuple of items. + This is implemented in order to avoid using frozendict's __hash__, which does not use + cls.dict_cls.items(). 
+ """ + if self._hash is None: + self._hash = hash((self.dict_cls, frozenset(self.items()))) + + return self._hash diff --git a/src/starkware/starkware_utils/error_handling.py b/src/starkware/starkware_utils/error_handling.py new file mode 100644 index 00000000..a50c6b83 --- /dev/null +++ b/src/starkware/starkware_utils/error_handling.py @@ -0,0 +1,209 @@ +import operator +from enum import Enum, auto +from typing import Any, Dict, Optional + +symbol_to_function = {'!=': operator.ne, '==': operator.eq, '>': operator.gt, '>=': operator.ge} + + +class ErrorCode(Enum): + """ + Base class of all error code enums. + Do not add enum members to this class, only functionality. + See: https://docs.python.org/3/library/enum.html#restricted-enum-subclassing. + """ + + +class StarkErrorCode(ErrorCode): + #: Api function temporarily disabled. + API_FUNCTION_TEMPORARILY_DISABLED = 0 + #: Batch creation failure; batch currently cannot be created. + BATCH_CREATION_FAILURE = auto() + #: Batch is full; there will be no additional attempt to insert any transactions. + BATCH_FULL = auto() + #: Batch not ready to be created; does not indicate an error. + BATCH_NOT_READY = auto() + #: Order amount exceeds capacity. + CONFLICTING_ORDER_AMOUNTS = auto() + #: Fact not registered in fact registry. + FACT_NOT_REGISTERED = auto() + #: Not enough onchain balance to complete deposit. + INSUFFICIENT_ONCHAIN_BALANCE = auto() + #: Invalid batch ID. + INVALID_BATCH_ID = auto() + #: Invalid committee claim hash. + INVALID_CLAIM_HASH = auto() + #: Invalid committee member key. + INVALID_COMMITTEE_MEMBER = auto() + #: StarkEx contracts information missing or corrupt. + INVALID_CONTRACT_ADDRESS = auto() + #: Invalid response from a contract (for example, Infura too many requests). + INVALID_CONTRACT_RESPONSE = auto() + #: StarkEx deployment information missing or corrupt. + INVALID_DEPLOYMENT_INFO = auto() + #: Invalid eth address. + INVALID_ETH_ADDRESS = auto() + #: Fact is not 32 bytes length. 
+ INVALID_FACT = auto() + #: Fee taken is too high. + INVALID_FEE_TAKEN = auto() + #: Invalid order ID. + INVALID_ORDER_ID = auto() + #: Invalid order type. + INVALID_ORDER_TYPE = auto() + #: Invalid HTTP request. + INVALID_REQUEST = auto() + #: Invalid HTTP request parameters. + INVALID_REQUEST_PARAMETERS = auto() + #: Settlement trade amounts mismatch. + INVALID_SETTLEMENT_INFO = auto() + #: Settlement trade ratio not satisfied. + INVALID_SETTLEMENT_RATIO = auto() + #: Mismatching tokens for orders in settlement. + INVALID_SETTLEMENT_TOKENS = auto() + #: Invalid order signature. + INVALID_SIGNATURE = auto() + #: Invalid transaction. + INVALID_TRANSACTION = auto() + #: Invalid transaction ID. + INVALID_TRANSACTION_ID = auto() + #: Invalid vault. + INVALID_VAULT = auto() + #: Malformed request. + MALFORMED_REQUEST = auto() + #: Pipeline object is missing because it was migrated from an older version object. + MIGRATED_PIPELINE_OBJECT_MISSING = auto() + #: One of the fee objects is missing while the other exists. + MISSING_FEE_OBJECT = auto() + #: The order is expired. + ORDER_OVERDUE = auto() + #: Positive amount value is out of range. + OUT_OF_RANGE_POSITIVE_AMOUNT = auto() + #: Amount value is out of range. + OUT_OF_RANGE_AMOUNT = auto() + #: Vault balance is out of range. + OUT_OF_RANGE_BALANCE = auto() + #: Batch ID value is out of range. + OUT_OF_RANGE_BATCH_ID = auto() + #: Expiration timestamp value is out of range. + OUT_OF_RANGE_EXPIRATION_TIMESTAMP = auto() + #: Nonce value is out of range. + OUT_OF_RANGE_NONCE = auto() + #: Oracle price quorum value is out of range. + OUT_OF_RANGE_ORACLE_PRICE_QUORUM = auto() + #: Order ID value is out of range. + OUT_OF_RANGE_ORDER_ID = auto() + #: Public key (Stark key) value is out of range. + OUT_OF_RANGE_PUBLIC_KEY = auto() + #: Signature subfield is out of range. + OUT_OF_RANGE_SIGNATURE_SUBFIELD = auto() + #: Token ID value is out of range. + OUT_OF_RANGE_TOKEN_ID = auto() + #: Vault ID value is out of range. 
+ OUT_OF_RANGE_VAULT_ID = auto() + #: Alternative transaction requested before for this transaction. Transaction is now valid. + REPLACED_BEFORE = auto() + #: Failed response for alternative transaction request. + REQUEST_FAILED = auto() + #: Object schema validation failed. + SCHEMA_VALIDATION_ERROR = auto() + #: Transaction received successfully by the gateway. + TRANSACTION_PENDING = auto() + TRANSACTION_RECEIVED = auto() + + +class WebFriendlyException(Exception): + """ + Base class for exception classes that are exposed to the user, usually in an HTTP response body. + """ + + def __init__(self, status_code: int, body: Dict[str, Any]): + self.status_code = status_code + self.body = body + super().__init__(status_code, body) + + +class StarkException(WebFriendlyException): + """ + Base class for exception classes representing flows under the user's control (for example, + an invalid transaction). + """ + + def __init__(self, code, message: Optional[str] = None): + self.code = code + self.message = message + super().__init__(status_code=500, body={'code': code, 'message': message}) + + def __repr__(self) -> str: + return f'{type(self).__name__}(code={self.code}, message={self.message})' + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, StarkException): + raise NotImplementedError + + return self.code == other.code and self.message == other.message + + +def stark_assert(expr: bool, code, message: Optional[str] = None): + """ + Verifies that the given expression is True. If not, raises a StarkException with the given + code and message. + """ + if not expr: + raise StarkException(code=code, message=message) + + +def stark_assert_eq(exp_val, actual_val, code, message: Optional[str] = None): + """ + Verifies that the expected value is equal to the actual value, raising a StarkException with + the appropriate code and message, where the expected and actual values are added to the message.
+ """ + _stark_assert_not_symbol(exp_val, actual_val, symbol='!=', code=code, message=message) + + +def stark_assert_ne(exp_val, actual_val, code, message: Optional[str] = None): + """ + Verifies that the expected value is not equal to the actual value, raising a StarkException + with the appropriate code and message, where the expected and actual values are added to the + message. + """ + _stark_assert_not_symbol(exp_val, actual_val, symbol='==', code=code, message=message) + + +def stark_assert_le(exp_val, actual_val, code, message: Optional[str] = None): + """ + Verifies that the expected value is less than or equal to the actual value, raising a + StarkException with the appropriate code and message, where the expected and actual values are + added to the message. + """ + _stark_assert_not_symbol(exp_val, actual_val, symbol='>', code=code, message=message) + + +def stark_assert_lt(exp_val, actual_val, code, message: Optional[str] = None): + """ + Verifies that the expected value is strictly less than the actual value, raising a + StarkException with the appropriate code and message, where the expected and actual values are + added to the message. + """ + _stark_assert_not_symbol(exp_val, actual_val, symbol='>=', code=code, message=message) + + +def _stark_assert_not_symbol( + exp_val, actual_val, symbol: str, code, message: Optional[str] = None): + """ + Receives a symbol as a string that compares two values (e.g '==', '>') and verifies that: + `not exp_val symbol actual_val`. + + Example: + _stark_assert_not_symbol(3, 2, '==', code) -> Does nothing + _stark_assert_not_symbol(3, 3, '==', code) -> Raises an exception + + the given symbol must be mapped by the dict `symbol_to_function` to a function that performs the + symbol on two values. 
+ """ + MIN_HEX_SIZE = 1 << 128 + + format_val = lambda val: hex(val) if isinstance(val, int) and val > MIN_HEX_SIZE else val + if symbol_to_function[symbol](exp_val, actual_val): + eq_log = f'{format_val(exp_val)} {symbol} {format_val(actual_val)}' + message = f'{message}\n{eq_log}' if message else eq_log + raise StarkException(code=code, message=message) diff --git a/src/starkware/starkware_utils/field_validators.py b/src/starkware/starkware_utils/field_validators.py new file mode 100644 index 00000000..24322c58 --- /dev/null +++ b/src/starkware/starkware_utils/field_validators.py @@ -0,0 +1,257 @@ +import os +from typing import Callable, Dict, Iterable, List, Optional, TypeVar, Union + +import marshmallow +import marshmallow.validate +from web3 import Web3 + +DNS_REGEX = r'^((?!-)[${}a-z0-9-]{1,63}(? ValidatorType: + error_message = 'Invalid {field_name}: {{input}}; must be a legal {regex_description}'.format( + field_name=field_name, regex_description=regex_description) + validate_regex = marshmallow.validate.Regexp(regex=regex, error=error_message) + + def validator(value): + if value is None: + if allow_none: + return True + + raise marshmallow.exceptions.ValidationError(message=error_message.format(input=value)) + + return validate_regex(value) + + return validator + + +def validate_dns(*, allow_none: bool) -> ValidatorType: + return validate_regex_match( + field_name='dns', regex=DNS_REGEX, allow_none=allow_none, regex_description='DNS label') + + +def validate_url( + *, url_name: str, schemes: marshmallow.types.StrSequenceOrSet, + require_full_url: bool) -> ValidatorType: + error_message = ( + 'Invalid {url_name} URL: {{input}}; ' + 'must be a legal URL starting with {schemes}').format( + url_name=url_name, schemes=','.join(schemes)) + return marshmallow.validate.URL( + schemes=schemes, require_tld=require_full_url, error=error_message) + + +validate_gateway_url = validate_url( + url_name='Gateway endpoint', schemes={'http', 'https'}, 
require_full_url=False) + +validate_internal_url = validate_url( + url_name='Internal Gateway endpoint', schemes={'http', 'https'}, require_full_url=False) + +validate_node_endpoint = validate_url( + url_name='Node endpoint', schemes={'http', 'https'}, require_full_url=False) + + +def validate_one_of( + field_name: str, *, choices: Iterable, allow_none: bool = False) -> ValidatorType: + error_message = 'Invalid {field_name}: {{input}}; allowed values: {{choices}}'.format( + field_name=field_name) + one_of_validator = marshmallow.validate.OneOf(choices=choices, error=error_message) + + def validator(value): + if allow_none and value is None: + return True + + return one_of_validator(value) + + return validator + + +def validate_equal(field_name: str, *, allowed_value: T) -> ValidatorType: + error_message = 'Invalid {field_name}: {{input}}; must be: {{other}}'.format( + field_name=field_name) + return marshmallow.validate.Equal(comparable=allowed_value, error=error_message) + + +def validate_in_range( + field_name, *, min_value: Optional[int] = None, max_value: Optional[int] = None, + min_inclusive: bool = True, max_inclusive: bool = True, + allow_none: bool = False, error_message: Optional[str] = None) -> ValidatorType: + if error_message is None: + range_string = ( + f'{"[" if min_inclusive else "("}' + f'{"-inf" if min_value is None else min_value},' + f'{"inf" if max_value is None else max_value}' + f'{"]" if max_inclusive else ")"}') + error_message = \ + 'Invalid {field_name}: {{input}}; must be in the range {range_string}'.format( + field_name=field_name, range_string=range_string) + + range_validator = marshmallow.validate.Range( + min=min_value, max=max_value, min_inclusive=min_inclusive, max_inclusive=max_inclusive, + error=error_message) + + def validator(value): + if allow_none and value is None: + return True + + return range_validator(value) + + return validator + + +def validate_positive(field_name: str, *, allow_none: bool = False) -> ValidatorType: + 
def validate_positive_or_infinity(field_name: str) -> ValidatorType:
    """
    Returns a validator that accepts strictly positive values, and also the special value -1
    (which stands for "unlimited").
    The validator returns True for a valid value and raises marshmallow.ValidationError otherwise.
    """
    # The doubled braces keep '{input}' as a placeholder, which is filled in with the offending
    # value only when the validation fails.
    error_message = (
        'Invalid {field_name}: {{input}}; must be positive or -1 (for unlimited)').format(
            field_name=field_name)

    def validator(value):
        if value <= 0 and value != -1:
            raise marshmallow.ValidationError(error_message.format(input=value))

        return True

    return validator
def validate_communication_params(*, url: str, certificates_path: Optional[str]):
    """
    Checks that certificates are only configured together with a HTTPS URL.
    Raises ValueError if certificates_path is given while the URL does not start with 'https'.
    """
    # The URL is inspected first, regardless of whether certificates were supplied.
    is_https_url = url.startswith('https')
    if certificates_path is None or is_https_url:
        return

    raise ValueError('Certificates should be used together with a HTTPS URL')
def validate_power_of_two(field_name: str) -> ValidatorType:
    """
    Return a validator for a number, that validates that the number is a power of 2.
    The validator returns True for a valid value and raises marshmallow.ValidationError otherwise.
    """
    error_message = 'Invalid {field_name}: {{input}}; must be a power of 2'.format(
        field_name=field_name)

    def validator(value):
        remaining = value
        while remaining > 1:
            # Stop as soon as the value is not evenly divisible by 2. Comparing the remainder
            # to 0 (rather than checking that it equals 1) also rejects non-integers such as
            # 2.5, which would otherwise be floor-divided down to 1.0 and wrongly accepted.
            if remaining % 2 != 0:
                break
            remaining = remaining // 2

        # Values that are not positive never enter the loop and are rejected here.
        if remaining != 1:
            raise marshmallow.ValidationError(error_message.format(input=value))
        return True

    return validator
class IntAsHex(mfields.Field):
    """
    A field that behaves like an integer, but serializes to a hex string. Usually, this applies to
    field elements.
    Only lower case hex strings with the '0x' prefix are accepted when deserializing.
    """

    default_error_messages = {'invalid': 'Expected hex string, got: "{input}".'}

    def _serialize(self, value, attr, obj, **kwargs):
        if value is None:
            return None
        assert isinstance(value, int)
        return hex(value)

    def _deserialize(self, value, attr, data, **kwargs):
        # A non-string input is reported through the field's 'invalid' error, instead of leaking
        # a TypeError from the regex engine.
        # fullmatch is used since '$' also matches just before a trailing newline, which would
        # let inputs such as '0x1f\n' through.
        if not isinstance(value, str) or re.fullmatch('0x[0-9a-f]+', value) is None:
            self.fail('invalid', input=value)

        return int(value, 16)
+ """ + + default_error_messages = {'invalid': 'Expected hex string, got: "{input}".'} + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + assert isinstance(value, bytes) + return value.hex() + + def _deserialize(self, value, attr, data, **kwargs): + if re.match('^[0-9a-f]*$', value) is None: + self.fail('invalid', input=value) + + return bytes.fromhex(value) + + +class BytesAsBase64Str(mfields.Field): + """ + A field that behaves like bytes, but serializes to base64. + """ + + default_error_messages = {'invalid': 'Expected Base64 bytes, got: "{input}".'} + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + assert isinstance(value, bytes) + return base64.b64encode(value).decode('ascii') + + def _deserialize(self, value, attr, data, **kwargs): + return base64.b64decode(value.encode('ascii')) + + +class CustomField(ABC): + """ + A class representing a field deserialized into a variable of a specific type. + """ + + @property + @classmethod + @abstractmethod + def _type(cls) -> type: + pass + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) # type: ignore[call-arg] + + assert issubclass(cls, FieldABC), \ + 'CustomField must be used along with inheritance from a marshmallow field.' + + def _deserialize(self, *args, **kwargs): + return self._type(super()._deserialize(*args, **kwargs)) # type: ignore + + +class SetField(CustomField, mfields.List): + _type = set + + +class VariadicLengthTupleField(CustomField, mfields.List): + _type = tuple + + +class FrozenDictField(CustomField, mfields.Mapping): + _type = frozendict + + +class CustomRaisingDictField(CustomField, mfields.Mapping): + _type = CustomRaisingDict + + +class CustomRaisingFrozenDictField(CustomField, mfields.Mapping): + _type = CustomRaisingFrozenDict + + +# Field metadata for general use in marshmallow dataclasses. 
TSerializableObject = TypeVar('TSerializableObject', bound='Serializable')
TStrSerializableObject = TypeVar('TStrSerializableObject', bound='StringSerializable')


class Serializable(ABC):
    """
    Base class to classes whose objects can be (de)serialized.
    """

    @abstractmethod
    def serialize(self) -> bytes:
        pass

    @classmethod
    @abstractmethod
    def deserialize(cls: Type[TSerializableObject], data: bytes) -> TSerializableObject:
        pass


class StringSerializable(Serializable):
    """
    A class that has dumps and loads functions to convert its objects to and from strings.

    Classes implementing the dumps (and loads) methods can be automatically encoded using
    an extended JSON as follows:

    The get_encoder and get_decoder can be used to extend the JSON mechanism to encode and
    decode StringSerializable. This extended JSON can process any object built from
    1. JSON-serializable objects: string, int, float, int or float derived enums, booleans, None
    2. StringSerializable class,
    3. Nested objects in lists, dictionaries and tuples.

    In order to be able to use this extended JSON to encode a StringSerializable class CLS,
    you need to:
    1. Make sure the class CLS is defined before using json,
    2. create an encoder=StringSerializable.get_encoder()
    3. use it in JSON by json.dumps(object_to_dump, cls=self.encoder)
    Similarly, in the decoder use instead:
    4. json.loads(encoded_string, cls=self.decoder)

    Remarks:
    1. Note that in order to encode and decode a class CLS, it must be defined before you create
       the encoder and the decoder.
    2. If a class B extends A which implements the dumps method, then when using this extended
       JSON mechanism, object of type B will be encoded and decoded as type A objects.
       If B wants to use a different loads function (even if it has the same dumps as A),
       then it needs to implement the dumps function.
    """

    # Maps a serialization name to the class whose loads() is used when decoding that name.
    _classes: ClassVar[Dict[str, Type['StringSerializable']]] = {}
    # The name under which objects of the class are encoded.
    _serialize_name: ClassVar[str]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)  # type: ignore[call-arg]

        # Look for closest parent that implements the dumps method. Use this class when
        # dumping and loading.
        for mro_cls in inspect.getmro(cls):
            if mro_cls is StringSerializable:
                # The dumps method is abstract.
                continue
            if 'dumps' in mro_cls.__dict__:
                cls._serialize_name = f'{mro_cls}'
                # NOTE(review): a subclass that merely inherits dumps re-registers its parent's
                # name with the subclass; confirm that this is the intended behavior.
                StringSerializable._classes[cls._serialize_name] = cls
                # Stop at the closest implementation. Continuing along the MRO would pick the
                # farthest ancestor instead, and would overwrite that ancestor's registration.
                break

    @abstractmethod
    def dumps(self) -> str:
        pass

    @classmethod
    @abstractmethod
    def loads(cls: Type[TStrSerializableObject], data: str) -> TStrSerializableObject:
        pass

    def serialize(self) -> bytes:
        return self.dumps().encode('ascii')

    @classmethod
    def deserialize(cls: Type[TStrSerializableObject], data: bytes) -> TStrSerializableObject:
        return cls.loads(data=data.decode('ascii'))

    class SerializableEncoder(JSONEncoder):
        """
        A JSON encoder that encodes a registered StringSerializable object as a dictionary
        holding its serialization name and the result of its dumps().
        """

        def default(self, obj):
            if isinstance(obj, StringSerializable):
                if obj._serialize_name in StringSerializable._classes:
                    return {
                        '_serializable': obj._serialize_name,
                        'value': obj.dumps()
                    }

            # Anything else is handled (i.e., rejected with a TypeError) by the default encoder.
            return JSONEncoder.default(self, obj)

    @staticmethod
    def get_encoder() -> Type[JSONEncoder]:
        return StringSerializable.SerializableEncoder

    class SerializableDecoder(JSONDecoder):
        """
        A JSON decoder that turns the dictionaries produced by SerializableEncoder back into
        objects, using the loads() function of the registered class.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, object_hook=self.object_hook, **kwargs)

        def object_hook(self, obj):
            if '_serializable' not in obj:
                # A regular JSON object.
                return obj
            cls_repr = obj['_serializable']
            serialized_class = StringSerializable._classes.get(cls_repr, None)
            assert serialized_class is not None, f'Could not decode the class {cls_repr}.'
            return serialized_class.loads(data=obj['value'])

    @staticmethod
    def get_decoder() -> Type[JSONDecoder]:
        return StringSerializable.SerializableDecoder
+ """ + + class_name_prefix: ClassVar[bytes] + Schema: ClassVar[Type[marshmallow.Schema]] + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) # type: ignore[call-arg] + + cls.class_name_prefix = camel_to_snake_case(camel_case_name=cls.__name__).encode('ascii') + + def dump(self) -> dict: + return self.Schema().dump(obj=self) + + @classmethod + def load(cls: Type[TSerializableDataclass], data: dict) -> TSerializableDataclass: + return cls.Schema().load(data=data) + + def dumps(self) -> str: + return self.Schema().dumps(obj=self) + + @classmethod + def loads(cls: Type[TSerializableDataclass], data: str) -> TSerializableDataclass: + return cls.Schema().loads(json_data=data) + + @classmethod + def prefix(cls) -> bytes: + """ + Converts the class name to a lower case name with '_' as separators and returns the + bytes version of this name. For example HelloWorldAB -> b'hello_world_a_b'. + """ + return cls.class_name_prefix + + +class ValidatedDataclass: + """ + A class containing a type- and value-level validation. + """ + + def __post_init__(self): + self.validate_dataclass() + + def validate_dataclass(self): + self.validate_types() + self.validate_values() + + @classmethod + def get_random_element( + cls: Type[TValidatedDataclass], + random_object: Optional[random.Random] = None, **data) -> TValidatedDataclass: + """ + Generates a random object of the given class restricted by the given data. + Any field can be either passed as an argument (field_name=field_value), and if not, + it is generated randomly. + The random generation is done via the validated_field inside the metadata, or if there + is no such and the field is a ValidatedMarshmallow class, it recursively uses + get_random_element. + + Example usage: + @marshmallow_dataclasses.dataclass + class Inner(ValidatedMarshmallowDataclass): + a: int = field(validated_field=...) + b: int = field(validated_field=...) 
+ + @marshmallow_dataclasses.dataclass + class Outer(ValidatedMarshmallowDataclass): + c: int = field(validated_field=...) + d: int = field(validated_field=...) + inner: Inner + + Outer.get_random_element(c=5) # Randomize a, b and d. + """ + new_object_data = {} + for field in dataclasses.fields(cls): + # Fields with a value from the arguments. + if field.name in data.keys(): + new_object_data[field.name] = data[field.name] + continue + + # Fields without a value from the arguments. + validated_field = get_validated_field(field=field) + if validated_field is not None: + new_object_data[field.name] = validated_field.get_random_value( + random_object=random_object) + continue + + # The field is a validated class object. + is_validated_dataclass = ( + inspect.isclass(field.type) and + issubclass(field.type, ValidatedMarshmallowDataclass)) + if is_validated_dataclass: + new_object_data[field.name] = field.type.get_random_element( + random_object=random_object) + continue + + raise Exception( + f'Could not randomize the field {field.name} in an object of type {cls}.') + + return cls(**new_object_data) # type: ignore + + def validate_values(self): + for field in dataclasses.fields(self): + metadata = getattr(field, 'metadata', None) + if metadata is None: + continue + + value = getattr(self, field.name) + # First use the field_validated argument, and only if it does not exist, + # use the validation inside the marshmallow field argument. 
+ validated_field = metadata.get('validated_field', None) + if validated_field is None: + marshmallow_field = field.metadata.get('marshmallow_field', None) + if marshmallow_field is not None: + validate_field(field=marshmallow_field, value=value) + else: + name_in_messages = metadata.get('name_in_messages', None) + validated_field.validate(value=value, name=name_in_messages) + + def validate_types(self): + for field in dataclasses.fields(self): + typeguard.check_type( + argname=field.name, value=getattr(self, field.name), expected_type=field.type) + + +class ValidatedMarshmallowDataclass(ValidatedDataclass, SerializableMarshmallowDataclass): + """ + Base class to classes decorated with marshmallow_dataclass.dataclass, containing validations. + """ + + +def get_validated_field(field: dataclasses.Field) -> Optional[Field]: + """ + Checks if the dataclass field has a validated_field attribute in its metadata. + If so returns it, otherwise returns None. + """ + if field.metadata is not None and 'validated_field' in field.metadata: + return field.metadata['validated_field'] + return None + + +def late_marshmallow_dataclass(cls: Optional[type] = None, **kwargs): + """ + A helper function for creating marshmallow dataclasses while inheriting fields from base class. + + Example usage: + class Base: + x: T + y: int = 5 + + @marshmallow_dataclasses.dataclass + class Child(Base): + x: str + # y: int = 5 will be inherited from parent, due to late_marshmallow_dataclass. + + Note that no parent class of the annotated class should be a dataclass. + In case that a nondefault attribute follows a default attribute, it is not guaranteed that the + derived class construction will work as expected. + """ + if cls is None: # Arguments passed directly to decorator. 
+ def inner(cls): + prepare_class_annotations_and_attribute_values(cls) + return marshmallow_dataclass.dataclass(cls, **kwargs) + + return inner + + prepare_class_annotations_and_attribute_values(cls) + return marshmallow_dataclass.dataclass(cls) + + +def prepare_class_annotations_and_attribute_values(cls): + """ + Prepares class annotations in the following manner: + Annotations are added to __annotations__ dictionary in the reverse MRO order. Members with + default values are added last, in order for them to appear last in the auto-generated __init__ + signature. + In addition, sets values for attributes in cls.__dict__. + """ + annotations, attr_values = process_class_annotations_and_attribute_values(cls=cls) + set_class_annotations_and_attribute_values( + cls=cls, annotations=annotations, attr_values=attr_values) + + +def process_class_annotations_and_attribute_values(cls) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Returns class attributes annotations and values. + The annotations and values are taken from the first class the attribute appears in its + annotations, in the cls' MRO order. + """ + annotations: Dict[str, Any] = {} + attr_values: Dict[str, Any] = {} + + for base_cls in inspect.getmro(cls): + if '__annotations__' not in base_cls.__dict__: + continue + + for name in base_cls.__annotations__: + if name in annotations: + # Attribute already seen in a derived class. + continue + + if name in base_cls.__dict__: + attr_values[name] = base_cls.__dict__[name] + continue + + if ('__dataclass_fields__' in base_cls.__dict__ and + name in base_cls.__dict__['__dataclass_fields__']): + # cls is a dataclass, in which all fields appear in cls.__dataclass_fields__, + # rather than directly in cls.__dict__. + attr_values[name] = base_cls.__dict__['__dataclass_fields__'][name] + continue + + # Prepand annotations, so that they appear in reverse MRO order. 
def set_class_annotations_and_attribute_values(
        cls, annotations: Dict[str, Any], attr_values: Dict[str, Any]):
    """
    Sets given attributes to cls.__dict__ and its annotations.
    The annotations will contain the given annotations, where the attributes with default values
    will appear last.
    """
    # Make sure the attributes appear directly in cls.__dict__ as well.
    default_value_annotations: Dict[str, Any] = {}
    for name, attr_value in attr_values.items():
        setattr(cls, name, attr_value)

        if has_default_value(cls=cls, attr_value=attr_value):
            default_value_annotations[name] = annotations[name]

    # Locate members with default values in the end of the annotations dictionary.
    cls.__annotations__ = {
        name: annotation for name, annotation in annotations.items()
        if name not in default_value_annotations}
    cls.__annotations__.update(default_value_annotations)


def has_default_value(cls, attr_value: Any) -> bool:
    """
    Returns whether the class member has a default value or not.
    attr_value is the value assigned to the member in the class body: either a plain value or a
    dataclasses.Field object. The cls argument is not used.
    """
    if not isinstance(attr_value, dataclasses.Field):
        # Plain default value assignment:
        #   class A:
        #       x: int = 1
        return True

    # If member does not appear in __init__'s signature, having a default value is irrelevant.
    # This is checked separately, so that it applies to default_factory as well as to default
    # (in `a and b or c`, the `and` binds tighter, which let init=False members with a
    # default_factory through).
    if not attr_value.init:
        return False

    # Mypy has a problem with object members that are callables (it sees access to them as
    # passing self). This is actually originated in dataclasses' annotations in typeshed, since
    # the source code has no annotations.
    # See https://github.com/python/mypy/issues/6910 for details on this problem.
    return (
        attr_value.default is not dataclasses.MISSING or
        attr_value.default_factory is not dataclasses.MISSING)  # type: ignore
+ Note: multiple validators are not currently supported as an iterable, but rather as a single + validation function that and-s between the validators' results. + """ + if field.validate is not None and callable(field.validate): + field.validate(value) + + +def validate_field(field: mfields.Field, value: Any): + validate_value(field=field, value=value) + + # Validate inner elements, if field is a container. + if isinstance(field, mfields.List): + validate_list(field, value) + elif isinstance(field, mfields.Mapping): + if field.key_field is not None: + validate_list(mfields.List(field.key_field), value.keys()) + if field.value_field is not None: + validate_list(mfields.List(field.value_field), value.values()) + + +def validate_list(list_field: mfields.List, list_value: Sequence): + if not isinstance(list_field.inner, mfields.Field): + # Nothing to check further, since it is not a marshmallow field. + return + + if list_value is None: + if list_field.allow_none: + return + + raise marshmallow.ValidationError('Field may not be None.') + + for inner_element in list_value: + validate_field(field=list_field.inner, value=inner_element) diff --git a/src/starkware/starkware_utils/validated_fields.py b/src/starkware/starkware_utils/validated_fields.py new file mode 100644 index 00000000..578340a3 --- /dev/null +++ b/src/starkware/starkware_utils/validated_fields.py @@ -0,0 +1,206 @@ +import dataclasses +import random +from abc import ABC, abstractmethod +from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Type, TypeVar + +import marshmallow.fields as mfields + +from starkware.python.utils import initialize_random +from starkware.starkware_utils.error_handling import ErrorCode, stark_assert +from starkware.starkware_utils.field_validators import validate_in_range +from starkware.starkware_utils.marshmallow_dataclass_fields import ( + BytesAsBase64Str, BytesAsHex, IntAsHex, IntAsStr) + +T = TypeVar('T') + + +class Field(ABC, Generic[T]): + """ + A class 
representing data types for fields in ValidatedMarshmallowDataclass. + A dataclass field using this should have the following in its metadata: + 1. Data needed for @dataclasses.dataclass fields: 'description', 'default', + 'default_factory', etc. , + 2. Data needed for marshmallow: in 'marshmallow_field' , + 3. An object implementing this Field class: in 'validated_field' , + 4. A name for messages: in 'name_in_messages'. + """ + + @property + @abstractmethod + def name(self) -> str: + """ + The default name that appears in messages. + """ + + @abstractmethod + def format(self, value) -> str: + """ + The formatted value that appears in messages. + """ + + # Randomization. + + @abstractmethod + def get_random_value(self, random_object: Optional[random.Random] = None) -> T: + """ + Returns a random valid value for this field. + """ + + # Validation. + + @abstractmethod + def is_valid(self, value: T) -> bool: + """ + Checks and returns if the given value is valid. + """ + + def validate(self, value: T, name: Optional[str] = None): + error_message = self.format_invalid_value_error_message(value=value, name=name) + stark_assert(self.is_valid(value=value), code=self.error_code, message=error_message) + + @property + @abstractmethod + def error_code(self) -> ErrorCode: + """ + The error code that is returned if the value is not valid. + """ + + @abstractmethod + def get_invalid_values(self) -> List[T]: + """ + Returns a list of invalid values for this field. + """ + + @abstractmethod + def format_invalid_value_error_message(self, value: T, name: Optional[str] = None) -> str: + """ + Constructs the error message for invalid values. + """ + + # Serialization. + @abstractmethod + def get_marshmallow_field(self) -> mfields.Field: + """ + Returns a marshmallow field that serializes and deserializes values of this field. + """ + + # Metadata. + + def metadata(self, field_name: Optional[str] = None): + """ + Creates the metadata associated with this field.
If provided, then use the given field_name + for messages, and otherwise (if it is None) use the default name. + """ + return dict( + marshmallow_field=self.get_marshmallow_field(), + validated_field=self, + name_in_messages=self.name if field_name is None else field_name) + + +@dataclasses.dataclass(frozen=True) +class RangeValidatedField(Field[int]): + """ + Represents a range-validated integer field. + """ + + lower_bound: int # Inclusive. + upper_bound: int # Non-inclusive. + name_in_error_message: str + out_of_range_error_code: ErrorCode + formatter: Optional[Callable[[int], str]] = None + out_of_range_message: ClassVar[str] = '{field_name} {field_value} is out of range' + + @property + def name(self): + return self.name_in_error_message + + def format(self, value: int) -> str: + return self._format_value(value=value) + + def get_random_value(self, random_object: Optional[random.Random] = None) -> int: + r = initialize_random(random_object) + return r.randrange(self.lower_bound, self.upper_bound) + + def is_valid(self, value: int) -> bool: + return self.lower_bound <= value < self.upper_bound + + def format_invalid_value_error_message(self, value: int, name: Optional[str] = None) -> str: + return self.out_of_range_message.format( + field_name=self.name if name is None else name, + field_value=self._format_value(value)) + + @property + def error_code(self) -> ErrorCode: + return self.out_of_range_error_code + + def get_invalid_values(self) -> List[int]: + return [self.lower_bound - 1, self.upper_bound] + + def _format_value(self, value: int) -> str: + if self.formatter is None: + return str(value) + return self.formatter(value) + + def get_marshmallow_field(self) -> mfields.Field: + if self.formatter == hex: + return IntAsHex(required=True) + if self.formatter == str: + return IntAsStr(required=True) + if self.formatter is None: + return mfields.Integer(required=True) + raise NotImplementedError( + f'{self.name}: The given formatter {self.formatter.__name__} ' 
+ 'does not have a suitable metadata.') + + +# Field metadata utilities. + +def _generate_metadata( + marshmallow_field_cls: Type[mfields.Field], validated_field: Optional[Field], + required: Optional[bool] = None) -> Dict[str, Any]: + if required is None: + required = True + + metadata: Dict[str, Any] = dict(marshmallow_field=marshmallow_field_cls(required=required)) + if validated_field is not None: + metadata.update(validated_field=validated_field) + + return metadata + + +def int_metadata( + validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + return _generate_metadata( + marshmallow_field_cls=mfields.Integer, validated_field=validated_field, required=required) + + +def int_as_hex_metadata( + validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + return _generate_metadata( + marshmallow_field_cls=IntAsHex, validated_field=validated_field, required=required) + + +def int_as_str_metadata( + validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + return _generate_metadata( + marshmallow_field_cls=IntAsStr, validated_field=validated_field, required=required) + + +def bytes_as_hex_metadata( + validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + return _generate_metadata( + marshmallow_field_cls=BytesAsHex, validated_field=validated_field, required=required) + + +def bytes_as_base64_str_metadata( + validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + return _generate_metadata( + marshmallow_field_cls=BytesAsBase64Str, validated_field=validated_field, required=required) + + +def sequential_id_metadata( + field_name: str, required: bool = True, + allow_previous_id: bool = False) -> Dict[str, Any]: + validator = validate_in_range(field_name=field_name, min_value=-1 if allow_previous_id else 0) + return dict( + marshmallow_field=mfields.Integer(strict=True, required=required, validate=validator))