From f9ac4c10baa4c567e28d368096734c943b91dd53 Mon Sep 17 00:00:00 2001 From: Lior Goldberg Date: Mon, 28 Dec 2020 10:42:28 +0200 Subject: [PATCH] Cairo v0.0.1. --- .gitignore | 3 + .gitmodules | 3 + CMakeLists.txt | 24 + Dockerfile | 22 + LICENSE.txt | 15 + README.md | 57 + build.sh | 15 + licenses/CairoProgramLicense.txt | 37 + licenses/CairoToolchainLicense.txt | 35 + repos/CMakeLists.txt | 1 + repos/starkware-public | 1 + scripts/requirements-deps.json | 357 +++ scripts/requirements-gen.txt | 12 + scripts/requirements.txt | 27 + src/CMakeLists.txt | 1 + src/starkware/CMakeLists.txt | 2 + src/starkware/cairo/CMakeLists.txt | 4 + src/starkware/cairo/apps/CMakeLists.txt | 1 + .../cairo/apps/starkex2_0/CMakeLists.txt | 40 + .../starkex2_0/common/cairo_builtins.cairo | 27 + .../cairo/apps/starkex2_0/common/dict.cairo | 219 ++ .../common/merkle_multi_update.cairo | 190 ++ .../starkex2_0/common/merkle_update.cairo | 84 + .../apps/starkex2_0/common/registers.cairo | 29 + .../cairo/apps/starkex2_0/dex_constants.cairo | 36 + .../cairo/apps/starkex2_0/dex_context.cairo | 22 + .../cairo/apps/starkex2_0/execute_batch.cairo | 127 ++ .../execute_false_full_withdrawal.cairo | 59 + .../apps/starkex2_0/execute_limit_order.cairo | 145 ++ .../starkex2_0/execute_modification.cairo | 99 + .../apps/starkex2_0/execute_settlement.cairo | 74 + .../apps/starkex2_0/execute_transfer.cairo | 137 ++ .../apps/starkex2_0/hash_vault_ptr_dict.cairo | 60 + .../cairo/apps/starkex2_0/main.cairo | 137 ++ .../starkex2_0/starkex2_0_program_test.py | 40 + .../cairo/apps/starkex2_0/vault_update.cairo | 98 + .../apps/starkex2_0/verify_order_id.cairo | 50 + .../starkex2_0/verify_order_signature.cairo | 108 + src/starkware/cairo/bootloader/CMakeLists.txt | 35 + .../cairo/bootloader/hash_program.py | 38 + .../cairo/bootloader/hash_program_test.py | 8 + src/starkware/cairo/common/CMakeLists.txt | 25 + src/starkware/cairo/common/alloc.cairo | 6 + .../cairo/common/cairo_builtins.cairo | 22 + src/starkware/cairo/common/dict.cairo | 258 +++ src/starkware/cairo/common/dict.py | 58 + src/starkware/cairo/common/find_element.cairo | 99 + src/starkware/cairo/common/hash.cairo | 16 + src/starkware/cairo/common/hash_chain.cairo | 51 + src/starkware/cairo/common/hash_chain.py | 12 + src/starkware/cairo/common/hash_state.cairo | 77 + src/starkware/cairo/common/math.cairo | 262 +++ src/starkware/cairo/common/math_utils.py | 17 + src/starkware/cairo/common/memcpy.cairo | 38 + .../cairo/common/merkle_multi_update.cairo | 185 ++ .../cairo/common/merkle_update.cairo | 81 + src/starkware/cairo/common/registers.cairo | 49 + src/starkware/cairo/common/serialize.cairo | 48 + src/starkware/cairo/common/signature.cairo | 15 + src/starkware/cairo/lang/CMakeLists.txt | 25 + src/starkware/cairo/lang/MANIFEST.in | 1 + src/starkware/cairo/lang/__init__.py | 0 .../cairo/lang/builtins/CMakeLists.txt | 47 + .../builtins/builtin_runner_test_utils.py | 17 + .../checkpoints/checkpoints_builtin_runner.py | 39 + .../lang/builtins/checkpoints/instance_def.py | 8 + .../lang/builtins/hash/hash_builtin_runner.py | 81 + .../cairo/lang/builtins/hash/instance_def.py | 20 + .../lang/builtins/range_check/instance_def.py | 11 + .../range_check/range_check_builtin_runner.py | 84 + .../range_check_builtin_runner_test.py | 25 + .../lang/builtins/signature/instance_def.py | 15 + .../signature/signature_builtin_runner.py | 107 + .../signature_builtin_runner_test.py | 153 ++ .../cairo/lang/compiler/CMakeLists.txt | 132 ++ src/starkware/cairo/lang/compiler/__init__.py | 0 .../cairo/lang/compiler/assembler.py | 49 + .../cairo/lang/compiler/assembler_test.py | 53 + .../cairo/lang/compiler/ast/__init__.py | 0 .../cairo/lang/compiler/ast/arguments.py | 23 + .../cairo/lang/compiler/ast/bool_expr.py | 22 + .../cairo/lang/compiler/ast/cairo_types.py | 69 + .../cairo/lang/compiler/ast/code_elements.py | 573 +++++ src/starkware/cairo/lang/compiler/ast/expr.py | 303 +++ .../lang/compiler/ast/formatting_utils.py | 147 ++ .../compiler/ast/formatting_utils_test.py | 82 + .../cairo/lang/compiler/ast/instructions.py | 171 ++ .../cairo/lang/compiler/ast/module.py | 34 + src/starkware/cairo/lang/compiler/ast/node.py | 10 + .../cairo/lang/compiler/ast/notes.py | 61 + .../cairo/lang/compiler/ast/rvalue.py | 120 ++ .../cairo/lang/compiler/ast/types.py | 56 + .../cairo/lang/compiler/ast/visitor.py | 61 + .../cairo/lang/compiler/ast_objects_test.py | 483 +++++ src/starkware/cairo/lang/compiler/cairo.ebnf | 128 ++ .../cairo/lang/compiler/cairo_compile.py | 249 +++ .../cairo/lang/compiler/cairo_compile_test.py | 55 + .../cairo/lang/compiler/cairo_format.py | 42 + .../cairo/lang/compiler/const_expr_checker.py | 49 + .../cairo/lang/compiler/constants.py | 6 + .../cairo/lang/compiler/debug_info.py | 49 + src/starkware/cairo/lang/compiler/encode.py | 224 ++ .../cairo/lang/compiler/encode_test.py | 152 ++ .../cairo/lang/compiler/error_handling.py | 128 ++ .../lang/compiler/error_handling_test.py | 35 + .../lang/compiler/expression_evaluator.py | 51 + .../compiler/expression_evaluator_test.py | 33 + .../lang/compiler/expression_simplifier.py | 137 ++ .../compiler/expression_simplifier_test.py | 75 + .../lang/compiler/expression_transformer.py | 81 + src/starkware/cairo/lang/compiler/fields.py | 38 + .../lang/compiler/identifier_definition.py | 196 ++ .../compiler/identifier_definition_test.py | 83 + .../cairo/lang/compiler/identifier_manager.py | 307 +++ .../lang/compiler/identifier_manager_field.py | 28 + .../compiler/identifier_manager_field_test.py | 31 + .../lang/compiler/identifier_manager_test.py | 150 ++ .../cairo/lang/compiler/identifier_utils.py | 47 + .../lang/compiler/identifier_utils_test.py | 27 + .../cairo/lang/compiler/import_loader.py | 121 ++ .../cairo/lang/compiler/import_loader_test.py | 180 ++ .../cairo/lang/compiler/instruction.py | 105 + .../lang/compiler/instruction_builder.py | 492 +++++ .../lang/compiler/instruction_builder_test.py | 545 +++++ .../cairo/lang/compiler/instruction_test.py | 20 + .../cairo/lang/compiler/location_utils.py | 16 + .../cairo/lang/compiler/module_reader.py | 67 + .../cairo/lang/compiler/module_reader_test.py | 24 + src/starkware/cairo/lang/compiler/parser.py | 167 ++ .../cairo/lang/compiler/parser_errors_test.py | 117 + .../cairo/lang/compiler/parser_test.py | 497 +++++ .../cairo/lang/compiler/parser_transformer.py | 507 +++++ .../lang/compiler/preprocessor/__init__.py | 0 .../preprocessor/compound_expressions.py | 215 ++ .../preprocessor/compound_expressions_test.py | 346 +++ .../lang/compiler/preprocessor/conftest.py | 3 + .../cairo/lang/compiler/preprocessor/flow.py | 235 ++ .../lang/compiler/preprocessor/flow_test.py | 135 ++ .../preprocessor/identifier_collector.py | 198 ++ .../preprocessor/identifier_collector_test.py | 125 ++ .../compiler/preprocessor/local_variables.py | 206 ++ .../preprocessor/local_variables_test.py | 140 ++ .../compiler/preprocessor/preprocessor.py | 1358 ++++++++++++ .../preprocessor/preprocessor_error.py | 5 + .../preprocessor/preprocessor_test.py | 1912 +++++++++++++++++ .../preprocessor/preprocessor_test_utils.py | 25 + .../preprocessor/preprocessor_utils.py | 12 + .../compiler/preprocessor/reg_tracking.py | 128 ++ .../preprocessor/reg_tracking_test.py | 55 + src/starkware/cairo/lang/compiler/program.py | 121 ++ .../cairo/lang/compiler/references.py | 101 + .../cairo/lang/compiler/references_test.py | 23 + .../cairo/lang/compiler/scoped_name.py | 63 + .../cairo/lang/compiler/scoped_name_test.py | 39 + .../lang/compiler/substitute_identifiers.py | 27 + .../cairo/lang/compiler/test_utils.py | 6 + .../lang/compiler/type_system_visitor.py | 207 ++ .../lang/compiler/type_system_visitor_test.py | 98 + .../cairo/lang/ide/vim/ftdetect/cairo.vim | 3 + .../cairo/lang/ide/vim/ftplugin/cairo.vim | 6 + .../cairo/lang/ide/vim/syntax/cairo.vim | 30 + .../cairo/lang/ide/vscode-cairo/.gitignore | 4 + .../cairo/lang/ide/vscode-cairo/README.md | 31 + .../vscode-cairo/language-configuration.json | 19 + .../cairo/lang/ide/vscode-cairo/package.json | 76 + .../lang/ide/vscode-cairo/src/extension.ts | 55 + .../syntaxes/cairo.tmLanguage.json | 159 ++ .../cairo/lang/ide/vscode-cairo/tsconfig.json | 23 + src/starkware/cairo/lang/instances.py | 88 + .../cairo/lang/scripts/CMakeLists.txt | 7 + .../cairo/lang/scripts/cairo-compile | 10 + src/starkware/cairo/lang/scripts/cairo-format | 10 + src/starkware/cairo/lang/scripts/cairo-run | 10 + src/starkware/cairo/lang/setup.py | 29 + .../cairo/lang/tracer/CMakeLists.txt | 29 + src/starkware/cairo/lang/tracer/favicon.png | Bin 0 -> 9015 bytes src/starkware/cairo/lang/tracer/index.html | 55 + src/starkware/cairo/lang/tracer/tracer.css | 81 + src/starkware/cairo/lang/tracer/tracer.js | 379 ++++ src/starkware/cairo/lang/tracer/tracer.py | 125 ++ .../cairo/lang/tracer/tracer_data.py | 283 +++ .../cairo/lang/tracer/tracer_data_test.py | 86 + src/starkware/cairo/lang/vm/CMakeLists.txt | 96 + src/starkware/cairo/lang/vm/__init__.py | 0 .../cairo/lang/vm/air_public_input.py | 56 + src/starkware/cairo/lang/vm/builtin_runner.py | 185 ++ src/starkware/cairo/lang/vm/cairo_pie.py | 141 ++ src/starkware/cairo/lang/vm/cairo_pie_test.py | 50 + src/starkware/cairo/lang/vm/cairo_run.py | 337 +++ src/starkware/cairo/lang/vm/cairo_runner.py | 532 +++++ .../cairo/lang/vm/cairo_runner_test.py | 131 ++ src/starkware/cairo/lang/vm/crypto.py | 8 + src/starkware/cairo/lang/vm/memory_dict.py | 90 + .../cairo/lang/vm/memory_dict_test.py | 64 + .../cairo/lang/vm/memory_segments.py | 134 ++ .../cairo/lang/vm/memory_segments_test.py | 54 + .../cairo/lang/vm/output_builtin_runner.py | 151 ++ .../lang/vm/output_builtin_runner_test.py | 86 + src/starkware/cairo/lang/vm/relocatable.py | 136 ++ .../cairo/lang/vm/relocatable_fields.py | 37 + .../cairo/lang/vm/relocatable_fields_test.py | 29 + .../cairo/lang/vm/relocatable_test.py | 63 + src/starkware/cairo/lang/vm/security.py | 58 + src/starkware/cairo/lang/vm/security_test.py | 109 + src/starkware/cairo/lang/vm/test.cairo | 17 + src/starkware/cairo/lang/vm/trace_entry.py | 63 + .../cairo/lang/vm/trace_entry_test.py | 17 + src/starkware/cairo/lang/vm/utils.py | 37 + .../cairo/lang/vm/validated_memory_dict.py | 50 + .../lang/vm/validated_memory_dict_test.py | 52 + src/starkware/cairo/lang/vm/vm.py | 762 +++++++ src/starkware/cairo/lang/vm/vm_consts.py | 223 ++ src/starkware/cairo/lang/vm/vm_consts_test.py | 337 +++ src/starkware/cairo/lang/vm/vm_test.py | 576 +++++ src/starkware/python/CMakeLists.txt | 43 + src/starkware/python/__init__.py | 0 src/starkware/python/expression_string.py | 132 ++ .../python/expression_string_test.py | 40 + src/starkware/python/math_utils.py | 66 + src/starkware/python/math_utils_test.py | 64 + src/starkware/python/python_dependencies.py | 46 + src/starkware/python/test_utils.py | 33 + src/starkware/python/test_utils_test.py | 18 + src/starkware/python/utils.py | 129 ++ src/starkware/python/utils_test.py | 26 + 225 files changed, 24798 insertions(+) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 CMakeLists.txt create mode 100644 Dockerfile create mode 100644 LICENSE.txt create mode 100644 README.md create mode 100755 build.sh create mode 100644 licenses/CairoProgramLicense.txt create mode 100644 licenses/CairoToolchainLicense.txt create mode 100644 repos/CMakeLists.txt create mode 160000 repos/starkware-public create mode 100644 scripts/requirements-deps.json create mode 100644 scripts/requirements-gen.txt create mode 100644 scripts/requirements.txt create mode 100644 src/CMakeLists.txt create mode 100644 src/starkware/CMakeLists.txt create mode 100644 src/starkware/cairo/CMakeLists.txt create mode 100644 src/starkware/cairo/apps/CMakeLists.txt create mode 100644 src/starkware/cairo/apps/starkex2_0/CMakeLists.txt create mode 100644 src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/common/dict.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/common/merkle_multi_update.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/common/merkle_update.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/common/registers.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/dex_constants.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/dex_context.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/execute_batch.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/execute_false_full_withdrawal.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/execute_modification.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/main.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py create mode 100644 src/starkware/cairo/apps/starkex2_0/vault_update.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo create mode 100644 src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo create mode 100644 src/starkware/cairo/bootloader/CMakeLists.txt create mode 100644 src/starkware/cairo/bootloader/hash_program.py create mode 100644 src/starkware/cairo/bootloader/hash_program_test.py create mode 100644 src/starkware/cairo/common/CMakeLists.txt create mode 100644 src/starkware/cairo/common/alloc.cairo create mode 100644 src/starkware/cairo/common/cairo_builtins.cairo create mode 100644 src/starkware/cairo/common/dict.cairo create mode 100644 src/starkware/cairo/common/dict.py create mode 100644 src/starkware/cairo/common/find_element.cairo create mode 100644 src/starkware/cairo/common/hash.cairo create mode 100644 src/starkware/cairo/common/hash_chain.cairo create mode 100644 src/starkware/cairo/common/hash_chain.py create mode 100644 src/starkware/cairo/common/hash_state.cairo create mode 100644 src/starkware/cairo/common/math.cairo create mode 100644 src/starkware/cairo/common/math_utils.py create mode 100644 src/starkware/cairo/common/memcpy.cairo create mode 100644 src/starkware/cairo/common/merkle_multi_update.cairo create mode 100644 src/starkware/cairo/common/merkle_update.cairo create mode 100644 src/starkware/cairo/common/registers.cairo create mode 100644 src/starkware/cairo/common/serialize.cairo create mode 100644 src/starkware/cairo/common/signature.cairo create mode 100644 src/starkware/cairo/lang/CMakeLists.txt create mode 100644 src/starkware/cairo/lang/MANIFEST.in create mode 100644 src/starkware/cairo/lang/__init__.py create mode 100644 src/starkware/cairo/lang/builtins/CMakeLists.txt create mode 100644 src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py create mode 100644 src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py create mode 100644 src/starkware/cairo/lang/builtins/checkpoints/instance_def.py create mode 100644 src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py create mode 100644 src/starkware/cairo/lang/builtins/hash/instance_def.py create mode 100644 src/starkware/cairo/lang/builtins/range_check/instance_def.py create mode 100644 src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py create mode 100644 src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py create mode 100644 src/starkware/cairo/lang/builtins/signature/instance_def.py create mode 100644 src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py create mode 100644 src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py create mode 100644 src/starkware/cairo/lang/compiler/CMakeLists.txt create mode 100644 src/starkware/cairo/lang/compiler/__init__.py create mode 100644 src/starkware/cairo/lang/compiler/assembler.py create mode 100644 src/starkware/cairo/lang/compiler/assembler_test.py create mode 100644 src/starkware/cairo/lang/compiler/ast/__init__.py create mode 100644 src/starkware/cairo/lang/compiler/ast/arguments.py create mode 100644 src/starkware/cairo/lang/compiler/ast/bool_expr.py create mode 100644 src/starkware/cairo/lang/compiler/ast/cairo_types.py create mode 100644 src/starkware/cairo/lang/compiler/ast/code_elements.py create mode 100644 src/starkware/cairo/lang/compiler/ast/expr.py create mode 100644 src/starkware/cairo/lang/compiler/ast/formatting_utils.py create mode 100644 src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py create mode 100644 src/starkware/cairo/lang/compiler/ast/instructions.py create mode 100644 src/starkware/cairo/lang/compiler/ast/module.py create mode 100644 src/starkware/cairo/lang/compiler/ast/node.py create mode 100644 src/starkware/cairo/lang/compiler/ast/notes.py create mode 100644 src/starkware/cairo/lang/compiler/ast/rvalue.py create mode 100644 src/starkware/cairo/lang/compiler/ast/types.py create mode 100644 src/starkware/cairo/lang/compiler/ast/visitor.py create mode 100644 src/starkware/cairo/lang/compiler/ast_objects_test.py create mode 100644 src/starkware/cairo/lang/compiler/cairo.ebnf create mode 100644 src/starkware/cairo/lang/compiler/cairo_compile.py create mode 100644 src/starkware/cairo/lang/compiler/cairo_compile_test.py create mode 100644 src/starkware/cairo/lang/compiler/cairo_format.py create mode 100644 src/starkware/cairo/lang/compiler/const_expr_checker.py create mode 100644 src/starkware/cairo/lang/compiler/constants.py create mode 100644 src/starkware/cairo/lang/compiler/debug_info.py create mode 100644 src/starkware/cairo/lang/compiler/encode.py create mode 100644 src/starkware/cairo/lang/compiler/encode_test.py create mode 100644 src/starkware/cairo/lang/compiler/error_handling.py create mode 100644 src/starkware/cairo/lang/compiler/error_handling_test.py create mode 100644 src/starkware/cairo/lang/compiler/expression_evaluator.py create mode 100644 src/starkware/cairo/lang/compiler/expression_evaluator_test.py create mode 100644 src/starkware/cairo/lang/compiler/expression_simplifier.py create mode 100644 src/starkware/cairo/lang/compiler/expression_simplifier_test.py create mode 100644 src/starkware/cairo/lang/compiler/expression_transformer.py create mode 100644 src/starkware/cairo/lang/compiler/fields.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_definition.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_definition_test.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_manager.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_manager_field.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_manager_field_test.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_manager_test.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_utils.py create mode 100644 src/starkware/cairo/lang/compiler/identifier_utils_test.py create mode 100644 src/starkware/cairo/lang/compiler/import_loader.py create mode 100644 src/starkware/cairo/lang/compiler/import_loader_test.py create mode 100644 src/starkware/cairo/lang/compiler/instruction.py create mode 100644 src/starkware/cairo/lang/compiler/instruction_builder.py create mode 100644 src/starkware/cairo/lang/compiler/instruction_builder_test.py create mode 100644 src/starkware/cairo/lang/compiler/instruction_test.py create mode 100644 src/starkware/cairo/lang/compiler/location_utils.py create mode 100644 src/starkware/cairo/lang/compiler/module_reader.py create mode 100644 src/starkware/cairo/lang/compiler/module_reader_test.py create mode 100644 src/starkware/cairo/lang/compiler/parser.py create mode 100644 src/starkware/cairo/lang/compiler/parser_errors_test.py create mode 100644 src/starkware/cairo/lang/compiler/parser_test.py create mode 100644 src/starkware/cairo/lang/compiler/parser_transformer.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/__init__.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/conftest.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/flow.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/flow_test.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/local_variables.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/local_variables_test.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/preprocessor_error.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/preprocessor_utils.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/reg_tracking.py create mode 100644 src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py create mode 100644 src/starkware/cairo/lang/compiler/program.py create mode 100644 src/starkware/cairo/lang/compiler/references.py create mode 100644 src/starkware/cairo/lang/compiler/references_test.py create mode 100644 src/starkware/cairo/lang/compiler/scoped_name.py create mode 100644 src/starkware/cairo/lang/compiler/scoped_name_test.py create mode 100644 src/starkware/cairo/lang/compiler/substitute_identifiers.py create mode 100644 src/starkware/cairo/lang/compiler/test_utils.py create mode 100644 src/starkware/cairo/lang/compiler/type_system_visitor.py create mode 100644 src/starkware/cairo/lang/compiler/type_system_visitor_test.py create mode 100644 src/starkware/cairo/lang/ide/vim/ftdetect/cairo.vim create mode 100644 src/starkware/cairo/lang/ide/vim/ftplugin/cairo.vim create mode 100644 src/starkware/cairo/lang/ide/vim/syntax/cairo.vim create mode 100644 src/starkware/cairo/lang/ide/vscode-cairo/.gitignore create mode 100644 src/starkware/cairo/lang/ide/vscode-cairo/README.md create mode 100644 src/starkware/cairo/lang/ide/vscode-cairo/language-configuration.json create mode 100644 src/starkware/cairo/lang/ide/vscode-cairo/package.json create mode 100644 src/starkware/cairo/lang/ide/vscode-cairo/src/extension.ts create mode 100644 src/starkware/cairo/lang/ide/vscode-cairo/syntaxes/cairo.tmLanguage.json create mode 100644 src/starkware/cairo/lang/ide/vscode-cairo/tsconfig.json create mode 100644 src/starkware/cairo/lang/instances.py create mode 100644 src/starkware/cairo/lang/scripts/CMakeLists.txt create mode 100755 src/starkware/cairo/lang/scripts/cairo-compile create mode 100755 src/starkware/cairo/lang/scripts/cairo-format create mode 100755 src/starkware/cairo/lang/scripts/cairo-run create mode 100644 src/starkware/cairo/lang/setup.py create mode 100644 src/starkware/cairo/lang/tracer/CMakeLists.txt create mode 100644 src/starkware/cairo/lang/tracer/favicon.png create mode 100644 src/starkware/cairo/lang/tracer/index.html create mode 100644 src/starkware/cairo/lang/tracer/tracer.css create mode 100644 src/starkware/cairo/lang/tracer/tracer.js create mode 100755 src/starkware/cairo/lang/tracer/tracer.py create mode 100644 src/starkware/cairo/lang/tracer/tracer_data.py create mode 100644 src/starkware/cairo/lang/tracer/tracer_data_test.py create mode 100644 src/starkware/cairo/lang/vm/CMakeLists.txt create mode 100644 src/starkware/cairo/lang/vm/__init__.py create mode 100644 src/starkware/cairo/lang/vm/air_public_input.py create mode 100644 src/starkware/cairo/lang/vm/builtin_runner.py create mode 100644 src/starkware/cairo/lang/vm/cairo_pie.py create mode 100644 src/starkware/cairo/lang/vm/cairo_pie_test.py create mode 100644 src/starkware/cairo/lang/vm/cairo_run.py create mode 100644 src/starkware/cairo/lang/vm/cairo_runner.py create mode 100644 src/starkware/cairo/lang/vm/cairo_runner_test.py create mode 100644 src/starkware/cairo/lang/vm/crypto.py create mode 100644 src/starkware/cairo/lang/vm/memory_dict.py create mode 100644 src/starkware/cairo/lang/vm/memory_dict_test.py create mode 100644 src/starkware/cairo/lang/vm/memory_segments.py create mode 100644 src/starkware/cairo/lang/vm/memory_segments_test.py create mode 100644 src/starkware/cairo/lang/vm/output_builtin_runner.py create mode 100644 src/starkware/cairo/lang/vm/output_builtin_runner_test.py create mode 100644 src/starkware/cairo/lang/vm/relocatable.py create mode 100644 src/starkware/cairo/lang/vm/relocatable_fields.py create mode 100644 src/starkware/cairo/lang/vm/relocatable_fields_test.py create mode 100644 src/starkware/cairo/lang/vm/relocatable_test.py create mode 100644 src/starkware/cairo/lang/vm/security.py create mode 100644 src/starkware/cairo/lang/vm/security_test.py create mode 100644 src/starkware/cairo/lang/vm/test.cairo create mode 100644 src/starkware/cairo/lang/vm/trace_entry.py create mode 100644 src/starkware/cairo/lang/vm/trace_entry_test.py create mode 100644 src/starkware/cairo/lang/vm/utils.py create mode 100644 src/starkware/cairo/lang/vm/validated_memory_dict.py create mode 100644 src/starkware/cairo/lang/vm/validated_memory_dict_test.py create mode 100644 src/starkware/cairo/lang/vm/vm.py create mode 100644 src/starkware/cairo/lang/vm/vm_consts.py create mode 100644 src/starkware/cairo/lang/vm/vm_consts_test.py create mode 100644 src/starkware/cairo/lang/vm/vm_test.py create mode 100644 src/starkware/python/CMakeLists.txt create mode 100644 src/starkware/python/__init__.py create mode 100644 src/starkware/python/expression_string.py create mode 100644 src/starkware/python/expression_string_test.py create mode 100644 src/starkware/python/math_utils.py create mode 100644 src/starkware/python/math_utils_test.py create mode 100644 src/starkware/python/python_dependencies.py create mode 100644 src/starkware/python/test_utils.py create mode 100644 src/starkware/python/test_utils_test.py create mode 100644 src/starkware/python/utils.py create mode 100644 src/starkware/python/utils_test.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..0d8ac0aa --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/build/ +__pycache__/ +cairo-starkware-*.zip diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..d5b172eb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "repos/starkware-public"] + path = repos/starkware-public + url = git@github.com:starkware-libs/starkex-resources-wip.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..fc23846f --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required (VERSION 3.5) + +project(CairoLang VERSION 0.1.0) +include(CTest) + +enable_testing() + +if (NOT DEFINED CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Debug) +endif() + +# Python library macro. +find_program(PYTHON "python3") + +include("repos/starkware-public/cmake_utils/exe_rules.cmake") +include("repos/starkware-public/cmake_utils/python_rules.cmake") +include("repos/starkware-public/cmake_utils/pip_rules.cmake") +python_get_pip_deps(main_reqs + python3.7:${CMAKE_SOURCE_DIR}/scripts/requirements-deps.json +) + +# Repos needs to be first as it defines macros that are needed by src. +add_subdirectory(repos) +add_subdirectory(src) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..97ce9ed0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,22 @@ +FROM ubuntu:18.04 + +RUN apt update +RUN apt install -y cmake python3.7 libgmp3-dev g++ python3-pip python3.7-dev npm + +COPY . /app/ + +# Build. +WORKDIR /app/ +RUN ./build.sh + +WORKDIR /app/build/Release +RUN make all -j8 + +# Run tests. +RUN ctest -V + +# Build the Visual Studio Code extension. +WORKDIR /app/src/starkware/cairo/lang/ide/vscode-cairo +RUN npm install -g vsce +RUN npm install +RUN vsce package diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 00000000..e8048fa3 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,15 @@ +The files in this repository are governed by several licenses, +which are located under the licenses/ directory. + +All code in this project except for the files under +'src/starkware/cairo/apps/starkex2_0/' and its subdirectories is +subject to the Cairo Toolchain License (Source Available), which +can be found under 'licenses/CairoToolchainLicense.txt'. + +The code under 'src/starkware/cairo/apps/starkex2_0/' and its +subdirectories is subject to the Cairo Program License (Source +Available), which can be found under +'licenses/CairoProgramLicense.txt'. + +For more infromation regarding licenses visit +https://starkware.co/licenses/. diff --git a/README.md b/README.md new file mode 100644 index 00000000..a3138635 --- /dev/null +++ b/README.md @@ -0,0 +1,57 @@ +# Introduction + +[Cairo](https://cairo-lang.org/) is a programming language for writing provable programs. + +# Documentation + +The Cairo documentation consists of two parts: "Hello Cairo" and "How Cairo Works?". +Both parts can be found in https://cairo-lang.org/docs/. + +We recommend starting from [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). + +# Installation instructions + +You should be able to download the python package zip file directly from +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.0.1) +and install it using ``pip``. +See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). + +However, if you want to build it yourself, you can build it from the git repository. +It is recommended to run the build inside a docker (as explained below), +since it guarantees that all the dependencies +are installed. Alternatively, you can try following the commands in the +[docker file](https://github.com/starkware-libs/cairo-lang/blob/master/Dockerfile). + +## Building using the dockerfile + +The root directory holds a dedicated Dockerfile, which automatically builds the package and runs +the unit tests on a simulated Ubuntu 18.04 environment. + +Clone the repository and initialize the git submodules using: + +```bash +> git clone git@github.com:starkware-libs/cairo-lang.git +> cd cairo-lang +> git submodule update --init +``` + +Build the docker image: + +```bash +> docker build --tag cairo . +``` + +If everything works, you should see + +```bash +Successfully tagged cairo:latest +``` + +Once the docker image is built, you can fetch the python package zip file using: + +```bash +> container_id=$(docker create cairo) +> docker cp ${container_id}:/app/cairo-starkware-0.0.1.zip . +> docker rm -v ${container_id} +``` + diff --git a/build.sh b/build.sh new file mode 100755 index 00000000..0a83f0db --- /dev/null +++ b/build.sh @@ -0,0 +1,15 @@ +set -e + +mkdir -p build/Release +( + cd build/Release + cmake ../.. -DCMAKE_BUILD_TYPE=Release + make -j8 cairo_lang_venv +) + +VENV_SITE_DIR=build/Release/src/starkware/cairo/lang/cairo_lang_venv-site +cp src/starkware/cairo/lang/setup.py ${VENV_SITE_DIR} +cp src/starkware/cairo/lang/MANIFEST.in ${VENV_SITE_DIR} +cp scripts/requirements-gen.txt ${VENV_SITE_DIR}/requirements.txt +( cd ${VENV_SITE_DIR}; python3 setup.py sdist --format=zip ) +cp ${VENV_SITE_DIR}/dist/cairo-starkware-0.0.1.zip . diff --git a/licenses/CairoProgramLicense.txt b/licenses/CairoProgramLicense.txt new file mode 100644 index 00000000..edf21f9c --- /dev/null +++ b/licenses/CairoProgramLicense.txt @@ -0,0 +1,37 @@ +Cairo Program License (Source Available) + +Version 1.0, November 2020 + +This license contains the terms and conditions under which StarkWare +Industries, Ltd ("StarkWare") makes available its StarkEx Cairo Software +("Cairo Software"). Your use of the Cairo Software is subject to these +terms and conditions. + +StarkWare grants you a license to use and distribute the Cairo Software +during the Test Period, only for writing Cairo programs. The Cairo Verifier +Smart Contract ("Verifier") is not part of the Cairo Software and is +subject to a separate license. The "Test Period" will end on June 30, 2021, +however, StarkWare may extend this date by posting a notice of extension +referencing this license on its web site +https://www.starkware.co/source-available-license/. + +These terms do not allow you to sublicense or transfer any of your rights to +anyone else. These terms do not imply any other licenses not expressly +granted in this license. + +If you violate any of these terms, use the Cairo Software in a way not +authorized under this license, your license ends immediately. If you make, +or authorize any other person to make, any written claim that the Cairo +Software infringes or contributes to infringement of any patent, all rights +you are granted under this license end immediately. In either case, for +purposes of your license, the Test Period will end. + +After the end of the Test Period, you will not be licensed to use the Cairo +Software to write Cairo programs, but you may retain and distribute copies of +the Cairo Software only as needed to allow Ethereum nodes to sync with the +Ethereum Mainnet and verify Ethereum transactions. + +As far as the law allows, the Cairo Software is provided AS IS, without any +warranty or condition, and StarkWare will not be liable to you for any damages +arising out of these terms or the use or nature of the Cairo Software and/or +the Verifier, under any kind of legal claim. diff --git a/licenses/CairoToolchainLicense.txt b/licenses/CairoToolchainLicense.txt new file mode 100644 index 00000000..275b21c8 --- /dev/null +++ b/licenses/CairoToolchainLicense.txt @@ -0,0 +1,35 @@ +Cairo Toolchain License (Source Available) + +Version 1.0 dated December 22, 2020 + +This license contains the terms and conditions under which StarkWare +Industries, Ltd ("StarkWare") makes available its Cairo Toolchain +("Toolchain"). Your use of the Toolchain is subject to these terms and +conditions. + +StarkWare grants you ("Licensee") a license to use the Toolchain, only +for the purpose of developing and compiling Cairo programs. Licensee's +use of the Toolchain is limited to non-commercial use, which means academic, +scientific, or research and development use, or evaluating the Cairo +language and Toolchain. + +StarkWare grants Licensee a license to modify the Toolchain, only as +necessary to fix errors. Licensee may, but is not obligated to, provide +any of Licensee's modifications to StarkWare. This license grants Licensee no +right to distribute the Toolchain or make copies of the Toolchain available +to others. + +These terms do not allow Licensee to sublicense or transfer any of Licensee's +rights to anyone else. These terms do not imply any other licenses not +expressly granted in this license. + +If Licensee violates any of these terms, or uses the Toolchain in a way not +authorized under this license, the license granted to Licensee ends immediately. +If Licensee makes, or authorizes any other person to make, any written claim +that the Toolchain infringes or contributes to infringement of any patent, all +rights granted to Licensee under this license end immediately. + +As far as the law allows, the Toolchain is provided AS IS, without any warranty +or condition, and StarkWare will not be liable to Licensee for any damages +arising out of these terms or the use or nature of the Toolchain, under any kind +of legal claim. diff --git a/repos/CMakeLists.txt b/repos/CMakeLists.txt new file mode 100644 index 00000000..77df6431 --- /dev/null +++ b/repos/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(starkware-public) diff --git a/repos/starkware-public b/repos/starkware-public new file mode 160000 index 00000000..24bb519d --- /dev/null +++ b/repos/starkware-public @@ -0,0 +1 @@ +Subproject commit 24bb519dca808591e1c069ff9d46bdcb370b3fd1 diff --git a/scripts/requirements-deps.json b/scripts/requirements-deps.json new file mode 100644 index 00000000..5f2a667e --- /dev/null +++ b/scripts/requirements-deps.json @@ -0,0 +1,357 @@ +[ + { + "dependencies": [], + "package": { + "installed_version": "20.2.0", + "key": "attrs", + "package_name": "attrs" + } + }, + { + "dependencies": [ + { + "installed_version": "1.15.0", + "key": "six", + "package_name": "six", + "required_version": ">=1.9.0" + } + ], + "package": { + "installed_version": "0.16.0", + "key": "ecdsa", + "package_name": "ecdsa" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "2.1.5", + "key": "fastecdsa", + "package_name": "fastecdsa" + } + }, + { + "dependencies": [ + { + "installed_version": "3.4.0", + "key": "zipp", + "package_name": "zipp", + "required_version": ">=0.5" + } + ], + "package": { + "installed_version": "2.0.0", + "key": "importlib-metadata", + "package_name": "importlib-metadata" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "1.1.1", + "key": "iniconfig", + "package_name": "iniconfig" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "0.8.5", + "key": "lark-parser", + "package_name": "lark-parser" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "3.8.0", + "key": "marshmallow", + "package_name": "marshmallow" + } + }, + { + "dependencies": [ + { + "installed_version": "3.8.0", + "key": "marshmallow", + "package_name": "marshmallow", + "required_version": ">=3.0.0,<4.0" + }, + { + "installed_version": "0.6.0", + "key": "typing-inspect", + "package_name": "typing-inspect", + "required_version": null + } + ], + "package": { + "installed_version": "8.1.0", + "key": "marshmallow-dataclass", + "package_name": "marshmallow-dataclass" + } + }, + { + "dependencies": [ + { + "installed_version": "3.8.0", + "key": "marshmallow", + "package_name": "marshmallow", + "required_version": ">=2.0.0" + } + ], + "package": { + "installed_version": "1.5.1", + "key": "marshmallow-enum", + "package_name": "marshmallow-enum" + } + }, + { + "dependencies": [ + { + "installed_version": "3.8.0", + "key": "marshmallow", + "package_name": "marshmallow", + "required_version": ">=3.0.0rc6,<4.0.0" + } + ], + "package": { + "installed_version": "2.1.0", + "key": "marshmallow-oneofschema", + "package_name": "marshmallow-oneofschema" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "1.1.0", + "key": "mpmath", + "package_name": "mpmath" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "0.4.3", + "key": "mypy-extensions", + "package_name": "mypy-extensions" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "1.19.2", + "key": "numpy", + "package_name": "numpy" + } + }, + { + "dependencies": [ + { + "installed_version": "2.4.7", + "key": "pyparsing", + "package_name": "pyparsing", + "required_version": ">=2.0.2" + }, + { + "installed_version": "1.15.0", + "key": "six", + "package_name": "six", + "required_version": null + } + ], + "package": { + "installed_version": "20.4", + "key": "packaging", + "package_name": "packaging" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "20.1.1", + "key": "pip", + "package_name": "pip" + } + }, + { + "dependencies": [ + { + "installed_version": "20.1.1", + "key": "pip", + "package_name": "pip", + "required_version": ">=6.0.0" + } + ], + "package": { + "installed_version": "1.0.0", + "key": "pipdeptree", + "package_name": "pipdeptree" + } + }, + { + "dependencies": [ + { + "installed_version": "2.0.0", + "key": "importlib-metadata", + "package_name": "importlib-metadata", + "required_version": ">=0.12" + } + ], + "package": { + "installed_version": "0.13.1", + "key": "pluggy", + "package_name": "pluggy" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "1.9.0", + "key": "py", + "package_name": "py" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "2.4.7", + "key": "pyparsing", + "package_name": "pyparsing" + } + }, + { + "dependencies": [ + { + "installed_version": "20.2.0", + "key": "attrs", + "package_name": "attrs", + "required_version": ">=17.4.0" + }, + { + "installed_version": "2.0.0", + "key": "importlib-metadata", + "package_name": "importlib-metadata", + "required_version": ">=0.12" + }, + { + "installed_version": "1.1.1", + "key": "iniconfig", + "package_name": "iniconfig", + "required_version": null + }, + { + "installed_version": "20.4", + "key": "packaging", + "package_name": "packaging", + "required_version": null + }, + { + "installed_version": "0.13.1", + "key": "pluggy", + "package_name": "pluggy", + "required_version": ">=0.12,<1.0" + }, + { + "installed_version": "1.9.0", + "key": "py", + "package_name": "py", + "required_version": ">=1.8.2" + }, + { + "installed_version": "0.10.1", + "key": "toml", + "package_name": "toml", + "required_version": null + } + ], + "package": { + "installed_version": "6.1.1", + "key": "pytest", + "package_name": "pytest" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "47.1.1", + "key": "setuptools", + "package_name": "setuptools" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "1.15.0", + "key": "six", + "package_name": "six" + } + }, + { + "dependencies": [ + { + "installed_version": "1.1.0", + "key": "mpmath", + "package_name": "mpmath", + "required_version": ">=0.19" + } + ], + "package": { + "installed_version": "1.6.2", + "key": "sympy", + "package_name": "sympy" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "0.10.1", + "key": "toml", + "package_name": "toml" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "3.7.4.3", + "key": "typing-extensions", + "package_name": "typing-extensions" + } + }, + { + "dependencies": [ + { + "installed_version": "0.4.3", + "key": "mypy-extensions", + "package_name": "mypy-extensions", + "required_version": ">=0.3.0" + }, + { + "installed_version": "3.7.4.3", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": ">=3.7.4" + } + ], + "package": { + "installed_version": "0.6.0", + "key": "typing-inspect", + "package_name": "typing-inspect" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "0.34.2", + "key": "wheel", + "package_name": "wheel" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "3.4.0", + "key": "zipp", + "package_name": "zipp" + } + } +] \ No newline at end of file diff --git a/scripts/requirements-gen.txt b/scripts/requirements-gen.txt new file mode 100644 index 00000000..0e4cf1e4 --- /dev/null +++ b/scripts/requirements-gen.txt @@ -0,0 +1,12 @@ +ecdsa +fastecdsa +lark-parser==0.8.5 +marshmallow-dataclass>=7.1.0 +marshmallow-enum +marshmallow-oneofschema +marshmallow>=3.2.1 +mpmath +numpy +pipdeptree +pytest +sympy diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 00000000..27932dac --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1,27 @@ +# This file is autogenerated. Do not edit manually. + +attrs==20.2.0 +ecdsa==0.16.0 +fastecdsa==2.1.5 +importlib-metadata==2.0.0 +iniconfig==1.1.1 +lark-parser==0.8.5 +marshmallow==3.8.0 +marshmallow-dataclass==8.1.0 +marshmallow-enum==1.5.1 +marshmallow-oneofschema==2.1.0 +mpmath==1.1.0 +mypy-extensions==0.4.3 +numpy==1.19.2 +packaging==20.4 +pipdeptree==1.0.0 +pluggy==0.13.1 +py==1.9.0 +pyparsing==2.4.7 +pytest==6.1.1 +six==1.15.0 +sympy==1.6.2 +toml==0.10.1 +typing-extensions==3.7.4.3 +typing-inspect==0.6.0 +zipp==3.4.0 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 00000000..9f9a94e4 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(starkware) diff --git a/src/starkware/CMakeLists.txt b/src/starkware/CMakeLists.txt new file mode 100644 index 00000000..0d5b491f --- /dev/null +++ b/src/starkware/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(cairo) +add_subdirectory(python) diff --git a/src/starkware/cairo/CMakeLists.txt b/src/starkware/cairo/CMakeLists.txt new file mode 100644 index 00000000..bfdbf2c9 --- /dev/null +++ b/src/starkware/cairo/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(apps) +add_subdirectory(bootloader) +add_subdirectory(common) +add_subdirectory(lang) diff --git a/src/starkware/cairo/apps/CMakeLists.txt b/src/starkware/cairo/apps/CMakeLists.txt new file mode 100644 index 00000000..3013e9fb --- /dev/null +++ b/src/starkware/cairo/apps/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(starkex2_0) diff --git a/src/starkware/cairo/apps/starkex2_0/CMakeLists.txt b/src/starkware/cairo/apps/starkex2_0/CMakeLists.txt new file mode 100644 index 00000000..a986645c --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/CMakeLists.txt @@ -0,0 +1,40 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +full_python_test(starkex2_0_program_test + PREFIX starkware/cairo/apps/starkex2_0 + PYTHON python3.7 + TESTED_MODULES starkware/cairo/apps/starkex2_0 + + FILES + common/cairo_builtins.cairo + common/dict.cairo + common/merkle_multi_update.cairo + common/merkle_update.cairo + common/registers.cairo + dex_constants.cairo + dex_context.cairo + execute_batch.cairo + execute_false_full_withdrawal.cairo + execute_limit_order.cairo + execute_modification.cairo + execute_settlement.cairo + execute_transfer.cairo + hash_vault_ptr_dict.cairo + main.cairo + starkex2_0_program_test.py + vault_update.cairo + verify_order_id.cairo + verify_order_signature.cairo + + LIBS + starkware_python_utils_lib + pip_pytest + + PY_EXE_DEPENDENCIES + cairo_compile_exe + cairo_hash_program_exe +) + diff --git a/src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo b/src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo new file mode 100644 index 00000000..7b7bbc88 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/common/cairo_builtins.cairo @@ -0,0 +1,27 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +# A representation of a HashBuiltin struct, specifying the hash builtin memory structure. +struct HashBuiltin: + member x = 0 + member y = 1 + member result = 2 + const SIZE = 3 +end + +# A representation of a SignatureBuiltin struct, specifying the signature builtin memory structure. +struct SignatureBuiltin: + member pub_key = 0 + member message = 1 + const SIZE = 2 +end + +# A representation of a CheckpointsBuiltin struct, specifying the checkpoints builtin memory +# structure. +struct CheckpointsBuiltin: + member required_pc = 0 + member required_fp = 1 + const SIZE = 2 +end diff --git a/src/starkware/cairo/apps/starkex2_0/common/dict.cairo b/src/starkware/cairo/apps/starkex2_0/common/dict.cairo new file mode 100644 index 00000000..d7d783b9 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/common/dict.cairo @@ -0,0 +1,219 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +struct DictAccess: + member key = 0 + member prev_value = 1 + member new_value = 2 + const SIZE = 3 +end + +# Inner tail-recursive function for squash_dict. +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. +# dict_accesses_end_minus1 - a pointer to the end of said array, minus 1. +# min_key - minimum allowed key. Used to enforce monotonicity of keys. +# remaining_accesses - remaining number of accesses that need to be accounted for. Starts with +# the total number of entries in dict_accesses array, and slowly decreases until it reaches 0. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Hints: +# keys - a descending list of the keys for which we have accesses. Destroyed in the process. +# access_indices - A map from key to a descending list of indices in the dict_accesses array that +# access this key. Destroyed in the process. +# +# Returns: +# range_check_ptr - updated range check builtin pointer. +# squashed_dict - end pointer to squashed_dict. +func squash_dict_inner( + range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end_minus1 : felt*, min_key, + remaining_accesses, squashed_dict : DictAccess*) -> ( + range_check_ptr, squashed_dict : DictAccess*): + # Exit recursion when done. + if remaining_accesses == 0: + %{ assert len(keys) == 0 %} + return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) + end + + # Locals. + struct Locals: + member key = 0 + member should_skip_loop = 1 + member first_value = 2 + const SIZE = 3 + end + let locals = cast(fp, Locals*) + let key = locals.key + let dict_diff : DictAccess* = squashed_dict + ap += Locals.SIZE + + # Guess key and check that key >= min_key. + %{ ids.locals.key = key = keys.pop() %} + [ap] = key - min_key + [ap] = [range_check_ptr]; ap++ + + # Loop to verify chronological accesses to the key. + # These values are not needed from previous iteration. + struct LoopTemps: + member index_delta_minus1 = 0 + member index_delta = 1 + member ptr_delta = 2 + member should_continue = 3 + const SIZE = 4 + end + # These values are needed from previous iteration. + struct LoopLocals: + member value = 0 + member access_ptr : DictAccess* = 1 + member range_check_ptr = 2 + const SIZE = 3 + end + + # Prepare first iteration. + %{ + current_access_indices = sorted(access_indices[key])[::-1] + current_access_index = current_access_indices.pop() + memory[ids.range_check_ptr + 1] = current_access_index + %} + # Check that first access_index >= 0. + tempvar current_access_index = [range_check_ptr + 1] + tempvar ptr_delta = current_access_index * DictAccess.SIZE + + let first_loop_locals = cast(ap, LoopLocals*) + first_loop_locals.access_ptr = dict_accesses + ptr_delta; ap++ + let first_access : DictAccess* = first_loop_locals.access_ptr + first_loop_locals.value = first_access.new_value; ap++ + first_loop_locals.range_check_ptr = range_check_ptr + 2; ap++ + + # Verify first key. + key = first_access.key + + # Write key and first value to dict_diff. + key = dict_diff.key + # Use a local variable, instead of a tempvar, to avoid increasing ap. + locals.first_value = first_access.prev_value + locals.first_value = dict_diff.prev_value + + # Skip loop non-deterministically if necessary. + %{ memory[fp + ids.Locals.should_skip_loop] = 0 if current_access_indices else 1 %} + jmp skip_loop if [fp + Locals.should_skip_loop] != 0 + + loop: + let prev_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + let loop_temps = cast(ap, LoopTemps*) + let loop_locals = cast(ap + LoopTemps.SIZE, LoopLocals*) + + # Check access_index. + %{ + new_access_index = current_access_indices.pop() + ids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1 + current_access_index = new_access_index + %} + # Check that new access_index > prev access_index. + loop_temps.index_delta_minus1 = [prev_loop_locals.range_check_ptr]; ap++ + loop_temps.index_delta = loop_temps.index_delta_minus1 + 1; ap++ + loop_temps.ptr_delta = loop_temps.index_delta * DictAccess.SIZE; ap++ + loop_locals.access_ptr = prev_loop_locals.access_ptr + loop_temps.ptr_delta; ap++ + + # Check valid transition. + let access : DictAccess* = loop_locals.access_ptr + prev_loop_locals.value = access.prev_value + loop_locals.value = access.new_value; ap++ + + # Verify key. + key = access.key + + # Next range_check_ptr. + loop_locals.range_check_ptr = prev_loop_locals.range_check_ptr + 1; ap++ + + %{ ids.loop_temps.should_continue = 1 if current_access_indices else 0 %} + jmp loop if loop_temps.should_continue != 0; ap++ + + skip_loop: + let last_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + + # Check if address is out of bounds. + %{ assert len(current_access_indices) == 0 %} + [ap] = dict_accesses_end_minus1 - cast(last_loop_locals.access_ptr, felt) + [ap] = [last_loop_locals.range_check_ptr]; ap++ + tempvar range_check_diff = last_loop_locals.range_check_ptr - range_check_ptr + tempvar n_used_accesses = range_check_diff - 1 + %{ assert ids.n_used_accesses == len(access_indices[key]) %} + + # Write last value to dict_diff. + last_loop_locals.value = dict_diff.new_value + + # Call squashed_dict_inner recursively. + squash_dict_inner( + range_check_ptr=last_loop_locals.range_check_ptr + 1, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end_minus1, + min_key=key + 1, + remaining_accesses=remaining_accesses - n_used_accesses, + squashed_dict=squashed_dict + DictAccess.SIZE) + return (...) +end + +# Verifies that dict_accesses lists valid chronological accesses (and updates) +# to a mutable dictionary and outputs a squashed dict with one DictAccess instance per key +# (value before and value after) which summarizes all the changes to that key. +# +# All keys are assumed to be in the range of the range check builtin (usually 2**128). +# +# Example: +# Input: {(key1, 0, 2), (key1, 2, 7), (key2, 4, 1), (key1, 7, 5), (key2, 1, 2)} +# Output: {(key1, 0, 5), (key2, 4, 2)} +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. The format of each +# entry is a triplet (key, prev_value, new_value). +# dict_accesses_end - a pointer to the end of said array. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Returns: +# range_check_ptr - updated range check builtin pointer. +# squashed_dict - end pointer to squashed_dict. +func squash_dict( + range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end : DictAccess*, + squashed_dict : DictAccess*) -> (range_check_ptr, squashed_dict : DictAccess*): + let ptr_diff = [fp] + %{ vm_enter_scope() %} + ptr_diff = dict_accesses_end - dict_accesses; ap++ + + if ptr_diff == 0: + # Access array is empty, nothing to check. + %{ vm_exit_scope() %} + return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) + end + + tempvar n_accesses = ptr_diff / DictAccess.SIZE + %{ + assert ids.ptr_diff % ids.DictAccess.SIZE == 0, \ + 'Accesses array size must be divisible by DictAccess.SIZE' + # A map from key to the list of indices accessing it. + access_indices = {} + for i in range(ids.n_accesses): + key = memory[ids.dict_accesses.address_ + ids.DictAccess.SIZE * i] + access_indices.setdefault(key, []).append(i) + # Descending list of keys. + keys = sorted(access_indices.keys())[::-1] + %} + + # Call inner. + squash_dict_inner( + range_check_ptr=range_check_ptr, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end - 1, + min_key=0, + remaining_accesses=n_accesses, + squashed_dict=squashed_dict) + %{ vm_exit_scope() %} + return (...) +end diff --git a/src/starkware/cairo/apps/starkex2_0/common/merkle_multi_update.cairo b/src/starkware/cairo/apps/starkex2_0/common/merkle_multi_update.cairo new file mode 100644 index 00000000..bc7e4378 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/common/merkle_multi_update.cairo @@ -0,0 +1,190 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess + +# Helper function for merkle_multi_update(). +func merkle_multi_update_inner( + hash_ptr : HashBuiltin*, update_ptr : DictAccess*, height, prev_root, new_root, index) -> ( + hash_ptr : HashBuiltin*, update_ptr : DictAccess*): + let hash0 : HashBuiltin* = hash_ptr + let hash1 : HashBuiltin* = hash_ptr + HashBuiltin.SIZE + %{ + if ids.height == 0: + assert node == ids.new_root, f'Expected node {ids.new_root}. Got {node}.' + case = 'leaf' + else: + prev_left, prev_right = preimage[ids.prev_root] + new_left, new_right = preimage[ids.new_root] + + left_child, right_child = node + if left_child is None: + assert right_child is not None, 'No updates in tree' + case = 'right' + elif right_child is None: + case = 'left' + else: + case = 'both' + + # Fill non deterministic hashes. + hash_ptr = ids.hash_ptr.address_ + memory[hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = prev_left + memory[hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = prev_right + memory[hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = new_left + memory[hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = new_right + + memory[ap] = int(case != 'right') + %} + jmp not_right if [ap] != 0; ap++ + + update_right: + prev_root = hash0.result + new_root = hash1.result + + # Make sure the same authentication path is used. + assert hash0.x = hash1.x + + # Call merkle_multi_update_inner recursively. + %{ vm_enter_scope(dict(node=right_child, preimage=preimage)) %} + merkle_multi_update_inner( + hash_ptr=hash_ptr + 2 * HashBuiltin.SIZE, + update_ptr=update_ptr, + height=height - 1, + prev_root=hash0.y, + new_root=hash1.y, + index=index * 2 + 1) + %{ vm_exit_scope() %} + return (...) + + not_right: + %{ memory[ap] = int(case != 'left') %} + jmp not_left if [ap] != 0; ap++ + + update_left: + prev_root = hash0.result + new_root = hash1.result + + # Make sure the same authentication path is used. + assert hash0.y = hash1.y + + # Call merkle_multi_update_inner recursively. + %{ vm_enter_scope(dict(node=left_child, preimage=preimage)) %} + merkle_multi_update_inner( + hash_ptr=hash_ptr + 2 * HashBuiltin.SIZE, + update_ptr=update_ptr, + height=height - 1, + prev_root=hash0.x, + new_root=hash1.x, + index=index * 2) + %{ vm_exit_scope() %} + return (...) + + not_left: + jmp update_both if height != 0 + + update_leaf: + # Note: height may underflow, but in order to reach 0 (which is verified here), we will need + # more steps than the field characteristic. The assumption is that it is not feasible. + + # Write the update. + let update : DictAccess* = update_ptr + %{ assert case == 'leaf' %} + index = update.key + prev_root = update.prev_value + new_root = update.new_value + + # Return values. + return (hash_ptr=hash_ptr, update_ptr=update + DictAccess.SIZE) + + update_both: + # Locals 0 and 1 are taken by non deterministic jumps. + let local_left_index = [fp + 2] + %{ assert case == 'both' %} + local_left_index = index * 2; ap++ + + prev_root = hash0.result + new_root = hash1.result + + # Update left. + %{ vm_enter_scope(dict(node=left_child, preimage=preimage)) %} + merkle_multi_update_inner( + hash_ptr=hash_ptr + 2 * HashBuiltin.SIZE, + update_ptr=update_ptr, + height=height - 1, + prev_root=hash0.x, + new_root=hash1.x, + index=index * 2) + %{ vm_exit_scope() %} + + # Update right. + # hash_ptr and update_ptr are already pushed. + # Push height to workaround one hint per line limitation. + [ap] = height - 1; ap++ # height. + %{ vm_enter_scope(dict(node=right_child, preimage=preimage)) %} + merkle_multi_update_inner(..., prev_root=hash0.y, new_root=hash1.y, index=local_left_index + 1) + %{ vm_exit_scope() %} + return (...) +end + +# Performs an efficient update of multiple leaves in a Merkle tree. +# +# Arguments: +# hash_ptr - hash builtin pointer. +# update_ptr - a list of DictAccess instances sorted by key (e.g., the result of squash_dict). +# height - height of merkle tree. +# prev_root - root value before the multi update. +# new_root - root value after the multi update. +# +# Hint arguments: +# preimage - a dictionary from the hash value of a merkle node to the pair of children values. +# +# Returns: +# hash_ptr - updated hash builtin pointer. +# +# Assumptions: The keys in the update_ptr list are unique and sorted. +# Guarantees: All the keys in the update_ptr list are < 2**height. +# +# Pseudocode: +# def diff(prev, new, height): +# if height == 0: return [(prev,new)] +# if prev.left==new.left: return diff(prev.right, new.right, height - 1) +# if prev.right==new.right: return diff(prev.left, new.left, height - 1) +# return diff(prev.left, new.left, height - 1) + \ +# diff(prev.right, new.right, height - 1) +func merkle_multi_update( + hash_ptr : HashBuiltin*, update_ptr : DictAccess*, n_updates, height, prev_root, + new_root) -> (hash_ptr : HashBuiltin*): + if n_updates == 0: + prev_root = new_root + return (hash_ptr=hash_ptr) + end + + %{ + from starkware.starkware_utils.merkle_tree.merkle_tree import build_update_tree + + # Build modifications list. + modifications = [] + for i in range(ids.n_updates): + curr_update_ptr = ids.update_ptr.address_ + i * ids.DictAccess.SIZE + modifications.append(( + memory[curr_update_ptr + ids.DictAccess.key], + memory[curr_update_ptr + ids.DictAccess.new_value])) + + node = build_update_tree(ids.height, modifications) + del modifications + vm_enter_scope(dict(node=node, preimage=preimage)) + %} + let ret_val = merkle_multi_update_inner( + hash_ptr=hash_ptr, + update_ptr=update_ptr, + height=height, + prev_root=prev_root, + new_root=new_root, + index=0) + assert ret_val.update_ptr = update_ptr + n_updates * DictAccess.SIZE + %{ vm_exit_scope() %} + return (hash_ptr=ret_val.hash_ptr) +end diff --git a/src/starkware/cairo/apps/starkex2_0/common/merkle_update.cairo b/src/starkware/cairo/apps/starkex2_0/common/merkle_update.cairo new file mode 100644 index 00000000..3fa5fc9a --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/common/merkle_update.cairo @@ -0,0 +1,84 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin + +# Performs an update for a single leaf (index) in a Merkle tree (where 0 <= index < 2^height). +# Updates the leaf from prev_leaf to new_leaf, and returns the previous and new roots of the +# Merkle tree resulting from the change. +# In particular, given a secret authentication path (of the siblings of the nodes in the path from +# the root to the leaf), this function computes the roots twice - once with prev_leaf and once with +# new_leaf, where the verifier is guaranteed that the same authentication path is used. +func merkle_update(hash_ptr, height, prev_leaf, new_leaf, index) -> (prev_root, new_root, hash_ptr): + if height == 0: + # Assert that index is 0. + index = 0 + # Return the two leaves and the Pedersen pointer. + %{ + # Check that auth_path had the right number of elements. + assert len(auth_path) == 0, 'Got too many values in auth_path.' + %} + return (prev_root=prev_leaf, new_root=new_leaf, hash_ptr=hash_ptr) + end + + %{ memory[ap] = ids.index % 2 %} + jmp update_right if [ap] != 0; ap++ + + update_left: + %{ + # Hash hints. + sibling = auth_path.pop() + memory[ids.hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = sibling + memory[ids.hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = sibling + %} + prev_leaf = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.x] + new_leaf = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.x] + + # Make sure the same authentication path is used. + let right_sibling = ap + [right_sibling] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.y] + [right_sibling] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.y]; ap++ + + # Call merkle_update recursively. + [ap] = hash_ptr + 2 * HashBuiltin.SIZE; ap++ # hash_ptr. + [ap] = height - 1; ap++ # height. + [ap] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # prev_leaf. + [ap] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # new_leaf. + + let update_left_index = ap + index = [update_left_index] * 2; ap++ # index. + merkle_update(...) # Tail recursion. + return (...) + + update_right: + %{ + # Hash hints. + sibling = auth_path.pop() + memory[ids.hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = sibling + memory[ids.hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = sibling + %} + prev_leaf = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.y] + new_leaf = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.y] + + # Make sure the same authentication path is used. + let left_sibling = ap + [left_sibling] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.x] + [left_sibling] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.x]; ap++ + + # Compute index - 1. + tempvar index_minus_one = index - 1 + + # Call merkle_update recursively. + [ap] = hash_ptr + 2 * HashBuiltin.SIZE; ap++ # hash_ptr. + [ap] = height - 1; ap++ # height. + [ap] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # prev_leaf. + [ap] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # new_leaf. + + let update_right_index = ap + # Compute (index - 1) / 2. + index_minus_one = [update_right_index] * 2; ap++ # index. + merkle_update(...) # Tail recursion. + return (...) +end diff --git a/src/starkware/cairo/apps/starkex2_0/common/registers.cairo b/src/starkware/cairo/apps/starkex2_0/common/registers.cairo new file mode 100644 index 00000000..acdbd76b --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/common/registers.cairo @@ -0,0 +1,29 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +# Returns the contents of the fp and pc registers of the calling function. +# The pc register's value is the address of the instruction that follows directly after the +# invocation of get_fp_and_pc(). +func get_fp_and_pc() -> (fp_val, pc_val): + # The call instruction itself already places the old fp and the return pc at [fp - 2], [fp - 1]. + # Thus, we can simply return, and the calling function may regard these as the return values + # of this function. + return (...) +end + +# Returns the content of the ap register just before this function was invoked. +func get_ap() -> (ap_val): + # Once get_ap() is invoked, fp points to ap + 2 (since the call instruction placed the old fp + # and pc in memory, advancing ap accordingly). + # Calling dummy_func places fp and pc at [fp], [fp + 1] (respectively), and advances ap by 2. + # Hence, going two cells above we get [fp] = ap + 2, and by subtracting 2 we get the desired ap + # value. + call dummy_func + return (ap_val=[ap - 2] - 2) +end + +func dummy_func(): + return () +end diff --git a/src/starkware/cairo/apps/starkex2_0/dex_constants.cairo b/src/starkware/cairo/apps/starkex2_0/dex_constants.cairo new file mode 100644 index 00000000..49af975c --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/dex_constants.cairo @@ -0,0 +1,36 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +# The hash of an empty vault (i.e., with balance = 0) is defined to be h(h(0,0),0). +const ZERO_VAULT_HASH = 3051532127692517571387022095821932649971160144101372951378323654799587621206 + +# Balance should be in the range [0, 2**63). +const BALANCE_BOUND = %[ 2 ** 63 %] + +# Nonce should be in the range [0, 2**31). +const NONCE_BOUND = %[ 2 ** 31 %] + +# Expiration timestamp should be in the range [0, 2**22). +const EXPIRATION_TIMESTAMP_BOUND = %[ 2 ** 22 %] + +# Order id should be in the range [0, 2**63). +const ORDER_ID_BOUND = %[ 2 ** 63 %] + +# The result of a hash builtin should be in the range [0, 2**251). +const HASH_MESSAGE_BOUND = %[ 2 ** 251 %] + +# The range-check builtin enables verifying that a value is within the range [0, 2**128). +const RANGE_CHECK_BOUND = %[ 2 ** 128 %] + +namespace PackedOrderMsg: + const SETTLEMENT_ORDER_TYPE = 0 + const TRANSFER_ORDER_TYPE = 1 + const CONDITIONAL_TRANSFER_ORDER_TYPE = 2 + # Vault shift in packed order message is 2**31, regardless of the actual vault tree height. + const VAULT_SHIFT = %[ 2 ** 31 %] + const AMOUNT_SHIFT = BALANCE_BOUND + const NONCE_SHIFT = NONCE_BOUND + const EXPIRATION_TIMESTAMP_SHIFT = EXPIRATION_TIMESTAMP_BOUND +end diff --git a/src/starkware/cairo/apps/starkex2_0/dex_context.cairo b/src/starkware/cairo/apps/starkex2_0/dex_context.cairo new file mode 100644 index 00000000..a2b2951b --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/dex_context.cairo @@ -0,0 +1,22 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.registers import get_fp_and_pc + +# A representation of a DEX context struct. + +struct DexContext: + member vault_tree_height = 0 + member order_tree_height = 1 + member global_expiration_timestamp = 2 + const SIZE = 3 +end + +# Returns a pointer to a new DexContext struct. +func make_dex_context(vault_tree_height, order_tree_height, global_expiration_timestamp) -> ( + addr : DexContext*): + let (__fp__, _) = get_fp_and_pc() + return (addr=cast(__fp__ - 2 - DexContext.SIZE, DexContext*)) +end diff --git a/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo b/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo new file mode 100644 index 00000000..f52ed8b2 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo @@ -0,0 +1,127 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.dex_context import DexContext +from starkware.cairo.apps.starkex2_0.execute_false_full_withdrawal import execute_false_full_withdrawal +from starkware.cairo.apps.starkex2_0.execute_modification import ModificationOutput +from starkware.cairo.apps.starkex2_0.execute_modification import execute_modification +from starkware.cairo.apps.starkex2_0.execute_settlement import execute_settlement +from starkware.cairo.apps.starkex2_0.execute_transfer import execute_transfer + +# Executes a batch of transactions (settlements, transfers, modifications). +func execute_batch( + modification_ptr : ModificationOutput*, conditional_transfer_ptr, hash_ptr : HashBuiltin*, + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, vault_dict : DictAccess*, + order_dict : DictAccess*, dex_context_ptr : DexContext*) -> ( + modification_ptr : ModificationOutput*, conditional_transfer_ptr, hash_ptr : HashBuiltin*, + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, vault_dict : DictAccess*, + order_dict : DictAccess*): + # Guess if the first transaction is a settlement. + jmp handle_settlement if [ap] != 0; ap++ + + # Guess if the first transaction is a transfer. + jmp handle_transfer if [ap] != 0; ap++ + + # Guess if the first transaction is a modification. + jmp handle_modification if [ap] != 0; ap++ + + # Otherwise, check that there are no other (undefined) transactions and return. + return ( + modification_ptr=modification_ptr, + conditional_transfer_ptr=conditional_transfer_ptr, + hash_ptr=hash_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + vault_dict=vault_dict, + order_dict=order_dict) + + handle_settlement: + # Call execute_settlement. + let settlement_res = execute_settlement( + hash_ptr=hash_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + vault_dict=vault_dict, + order_dict=order_dict, + dex_context_ptr=dex_context_ptr) + + # Call execute_batch recursively. + execute_batch( + modification_ptr=modification_ptr, + conditional_transfer_ptr=conditional_transfer_ptr, + hash_ptr=settlement_res.hash_ptr, + range_check_ptr=settlement_res.range_check_ptr, + ecdsa_ptr=settlement_res.ecdsa_ptr, + vault_dict=settlement_res.vault_dict, + order_dict=settlement_res.order_dict, + dex_context_ptr=dex_context_ptr) + return (...) + + handle_transfer: + # Call execute_transfer. + let transfer_res = execute_transfer( + hash_ptr=hash_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + conditional_transfer_ptr=conditional_transfer_ptr, + vault_dict=vault_dict, + order_dict=order_dict, + dex_context_ptr=dex_context_ptr) + + # Call execute_batch recursively. + execute_batch( + modification_ptr=modification_ptr, + conditional_transfer_ptr=transfer_res.conditional_transfer_ptr, + hash_ptr=transfer_res.hash_ptr, + range_check_ptr=transfer_res.range_check_ptr, + ecdsa_ptr=transfer_res.ecdsa_ptr, + vault_dict=transfer_res.vault_dict, + order_dict=transfer_res.order_dict, + dex_context_ptr=dex_context_ptr) + return (...) + + handle_modification: + # Guess if the first modification is a false full withdrawal. + jmp handle_false_full_withdrawal if [ap] != 0; ap++ + + # Call execute_modification. + let (range_check_ptr, modification_ptr, vault_dict) = execute_modification( + range_check_ptr=range_check_ptr, + modification_ptr=modification_ptr, + dex_context_ptr=dex_context_ptr, + vault_dict=vault_dict) + + # Call execute_batch recursively. + execute_batch( + modification_ptr=modification_ptr, + conditional_transfer_ptr=conditional_transfer_ptr, + hash_ptr=hash_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + vault_dict=vault_dict, + order_dict=order_dict, + dex_context_ptr=dex_context_ptr) + return (...) + + handle_false_full_withdrawal: + # Call execute_false_full_withdrawal. + let (vault_dict, modification_ptr) = execute_false_full_withdrawal( + modification_ptr=modification_ptr, dex_context_ptr=dex_context_ptr, vault_dict=vault_dict) + + # Call execute_batch recursively. + execute_batch( + modification_ptr=modification_ptr, + conditional_transfer_ptr=conditional_transfer_ptr, + hash_ptr=hash_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + vault_dict=vault_dict, + order_dict=order_dict, + dex_context_ptr=dex_context_ptr) + return (...) +end diff --git a/src/starkware/cairo/apps/starkex2_0/execute_false_full_withdrawal.cairo b/src/starkware/cairo/apps/starkex2_0/execute_false_full_withdrawal.cairo new file mode 100644 index 00000000..09b10cec --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/execute_false_full_withdrawal.cairo @@ -0,0 +1,59 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_context import DexContext +from starkware.cairo.apps.starkex2_0.execute_modification import ModificationOutput +from starkware.cairo.apps.starkex2_0.vault_update import vault_update_balances + +# Executes a false full withdrawal. +# Validates that the guessed requester_stark_key is not the same as the stark key in the vault +# and writes the requester_stark_key to the program output. +# Assumptions: keys in the vault_dict are range-checked to be < VAULT_SHIFT. +func execute_false_full_withdrawal( + modification_ptr : ModificationOutput*, dex_context_ptr : DexContext*, + vault_dict : DictAccess*) -> ( + vault_dict : DictAccess*, modification_ptr : ModificationOutput*): + let dex_context : DexContext* = dex_context_ptr + let output : ModificationOutput* = modification_ptr + + const FULL_WITHDRAWAL_SHIFT = ModificationOutput.FULL_WITHDRAWAL_SHIFT + const BALANCE_SHIFT = ModificationOutput.BALANCE_SHIFT + + alloc_locals + local stark_key + local balance_before + local token_id + local vault_index + + assert output.token_id = 0 + + # Note that we assume vault_index is range-checked during the merkle_multi_update, + # which will force the full withdrawal bit to be 1. + assert output.action = vault_index * BALANCE_SHIFT + BALANCE_BOUND + FULL_WITHDRAWAL_SHIFT + + # In false full withdrawal balance_before must be equal to balance_after. + vault_update_balances( + balance_before=balance_before, + balance_after=balance_before, + stark_key=stark_key, + token_id=token_id, + vault_index=vault_index, + vault_change_ptr=vault_dict) + + # Guess the requester_stark_key, write it to the output and make sure it's not the same as the + # stark_key. + let requester_stark_key = output.stark_key + tempvar key_diff = requester_stark_key - stark_key + if key_diff == 0: + # Add an unsatisfiable assertion when key_diff == 0. + key_diff = 1 + end + + return ( + vault_dict=vault_dict + DictAccess.SIZE, + modification_ptr=modification_ptr + ModificationOutput.SIZE) +end diff --git a/src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo b/src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo new file mode 100644 index 00000000..1f787e06 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo @@ -0,0 +1,145 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.common.merkle_update import merkle_update +from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import EXPIRATION_TIMESTAMP_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import NONCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import PackedOrderMsg +from starkware.cairo.apps.starkex2_0.dex_context import DexContext +from starkware.cairo.apps.starkex2_0.vault_update import vault_update_diff +from starkware.cairo.apps.starkex2_0.verify_order_signature import verify_order_signature + +# Executes a limit order of a single party. Each settlement will invoke this function twice, once +# per each party. +# A limit order can be described by the following statement: +# "I want to sell a maximum of amount_sell tokens of type token_sell, and in return I expect +# to receive at least amount_buy tokens of type token_buy (relative to the actual number of tokens +# sold)." +# +# The actual amounts that were transferred are amount_sold, amount_bought. +# +# sell_change and buy_change are DictAccess pointers into the vault_dict. +# They are given as two distinct pointers to allow the caller to control the order in which the +# vault updates are applied. +# +# Assumptions: +# * 0 <= amount_sold, amount_bought < BALANCE_BOUND. +# * 0 <= global_expiration_timestamp, and it has not expired yet. +func execute_limit_order( + hash_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + sell_change : DictAccess*, buy_change : DictAccess*, order_dict : DictAccess*, amount_sold, + amount_bought, token_sell, token_buy, dex_context_ptr : DexContext*) -> ( + hash_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + order_dict : DictAccess*): + # Local variables. + alloc_locals + local amount_sell + local amount_buy + local vault_id_sell + local vault_id_buy + local stark_key + local nonce + local order_id + local expiration_timestamp + local prev_fulfilled_amount + local new_fulfilled_amount + + let dex_context : DexContext* = dex_context_ptr + + # Define an inclusive amount bound reference for amount range-checks. + tempvar inclusive_amount_bound = BALANCE_BOUND - 1 + + # Check that 0 <= amount_sell < BALANCE_BOUND. + assert [range_check_ptr] = amount_sell + # Guarantee that amount_sell <= inclusive_amount_bound < BALANCE_BOUND. + assert [range_check_ptr + 1] = inclusive_amount_bound - amount_sell + + # Check that 0 <= amount_buy < BALANCE_BOUND. + assert [range_check_ptr + 2] = amount_buy + # Guarantee that amount_buy <= inclusive_amount_bound < BALANCE_BOUND. + assert [range_check_ptr + 3] = inclusive_amount_bound - amount_buy + + # Check that the party has not sold more than the sell amount limit specified in their order. + new_fulfilled_amount = prev_fulfilled_amount + amount_sold + # Guarantee that new_fulfilled_amount <= amount_sell, which also implies that + # amount_sold <= amount_sell. + assert [range_check_ptr + 4] = amount_sell - new_fulfilled_amount + + # Check that 0 <= nonce < NONCE_BOUND. + tempvar inclusive_nonce_bound = NONCE_BOUND - 1 + assert [range_check_ptr + 5] = nonce + # Guarantee that nonce <= inclusive_nonce_bound < NONCE_BOUND. + assert [range_check_ptr + 6] = inclusive_nonce_bound - nonce + + # Check that the order has not expired yet. + tempvar global_expiration_timestamp = dex_context.global_expiration_timestamp + # Guarantee that global_expiration_timestamp <= expiration_timestamp, which also implies that + # 0 <= expiration_timestamp. + assert [range_check_ptr + 7] = expiration_timestamp - global_expiration_timestamp + + # Check that expiration_timestamp < EXPIRATION_TIMESTAMP_BOUND. + tempvar inclusive_expiration_timestamp_bound = EXPIRATION_TIMESTAMP_BOUND - 1 + # Guarantee that expiration_timestamp <= inclusive_expiration_timestamp_bound < + # EXPIRATION_TIMESTAMP_BOUND. + assert [range_check_ptr + 8] = inclusive_expiration_timestamp_bound - expiration_timestamp + + # Check that the actual ratio (amount_bought / amount_sold) is better than (or equal to) the + # requested ratio (amount_buy / amount_sell) by checking that + # amount_sell * amount_bought >= amount_sold * amount_buy. + assert [range_check_ptr + 9] = amount_sell * amount_bought - amount_sold * amount_buy + + # Update orders dict. + let order_dict_access : DictAccess* = order_dict + order_id = order_dict_access.key + prev_fulfilled_amount = order_dict_access.prev_value + new_fulfilled_amount = order_dict_access.new_value + + # Call vault_update for selling, to update the vault tree with the new balance of the sell + # vault. + let sell_vault_update_ret = vault_update_diff( + range_check_ptr=range_check_ptr + 10, + diff=amount_sold * (-1), + stark_key=stark_key, + token_id=token_sell, + vault_index=vault_id_sell, + vault_change_ptr=sell_change) + + # Call vault_update for buying, to update the vault tree with the new balance of the buy vault. + # range_check_ptr is already in [ap - 1]. + let buy_vault_update_ret = vault_update_diff( + ..., + diff=amount_bought, + stark_key=stark_key, + token_id=token_buy, + vault_index=vault_id_buy, + vault_change_ptr=buy_change) + + let verify_order_signature_ret = verify_order_signature( + hash_ptr=hash_ptr, + range_check_ptr=buy_vault_update_ret.range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + public_key=stark_key, + order_type=PackedOrderMsg.SETTLEMENT_ORDER_TYPE, + vault0=vault_id_sell, + vault1=vault_id_buy, + amount0=amount_sell, + amount1=amount_buy, + token0=token_sell, + token1_or_pub_key=token_buy, + nonce=nonce, + expiration_timestamp=expiration_timestamp, + order_id=order_id, + condition=0) + + return ( + hash_ptr=verify_order_signature_ret.hash_ptr, + range_check_ptr=verify_order_signature_ret.range_check_ptr, + ecdsa_ptr=verify_order_signature_ret.ecdsa_ptr, + order_dict=order_dict + DictAccess.SIZE) +end diff --git a/src/starkware/cairo/apps/starkex2_0/execute_modification.cairo b/src/starkware/cairo/apps/starkex2_0/execute_modification.cairo new file mode 100644 index 00000000..b5667e1c --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/execute_modification.cairo @@ -0,0 +1,99 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_context import DexContext +from starkware.cairo.apps.starkex2_0.vault_update import vault_update_balances + +# Represents the struct of data written to the program output for each modification. +namespace ModificationOutput: + const BALANCE_SHIFT = %[ 2**64 %] + const VAULT_SHIFT = %[ 2**31 %] + const FULL_WITHDRAWAL_SHIFT = BALANCE_SHIFT * VAULT_SHIFT + + # The stark_key of the changed vault. + member stark_key = 0 + # The token_id of the token which was deposited or withdrawn. + member token_id = 1 + # A packed field which consists of the balances and vault_id. + # The format is as follows: + # +--------------------+------------------+----------------LSB-+ + # | full_withdraw (1b) | vault_idx (31b) | balance_diff (64b) | + # +--------------------+------------------+--------------------+ + # where balance_diff is represented using a 2**63 biased-notation. + member action = 2 + const SIZE = 3 +end + +# Executes a modification (deposit or withdrawal) which changes the balance in a single vault +# and writes the details of that change to the program output, so that the inverse operation +# may be performed by the solidity contract on the on-chain deposit/withdrawal vaults. +func execute_modification( + range_check_ptr, modification_ptr : ModificationOutput*, dex_context_ptr : DexContext*, + vault_dict : DictAccess*) -> ( + range_check_ptr, modification_ptr : ModificationOutput*, vault_dict : DictAccess*): + # Local variables. + alloc_locals + local balance_before + local balance_after + local vault_index + local is_full_withdrawal + + let dex_context : DexContext* = dex_context_ptr + let output : ModificationOutput* = modification_ptr + + # Copy constants to allow overriding them in the tests. + const BALANCE_SHIFT = ModificationOutput.BALANCE_SHIFT + const VAULT_SHIFT = ModificationOutput.VAULT_SHIFT + + # Perform range checks on balance_before, balance_after and vault_index to make sure + # their values are valid, and that they do not overlap in the modification action field. + tempvar inclusive_balance_bound = BALANCE_BOUND - 1 + + # Check that 0 <= balance_before < BALANCE_BOUND. + assert [range_check_ptr] = balance_before + # Guarantee that balance_before <= inclusive_balance_bound < BALANCE_BOUND. + assert [range_check_ptr + 1] = inclusive_balance_bound - balance_before + + # Check that 0 <= balance_after < BALANCE_BOUND. + assert [range_check_ptr + 2] = balance_after + # Guarantee that balance_after <= inclusive_balance_bound < BALANCE_BOUND. + assert [range_check_ptr + 3] = inclusive_balance_bound - balance_after + + # Note: This range-check is redundant as it is also checked in vault_update_balances. + # We keep it here for consistency with the other fields and to avoid the unnecessary dependency + # on the guarantees of vault_update_balances(). + assert [range_check_ptr + 4] = vault_index + # Guarantee that vault_index < VAULT_SHIFT. + assert [range_check_ptr + 5] = (VAULT_SHIFT - 1) - vault_index + + # Assert that is_full_withdrawal is a bit. + is_full_withdrawal = is_full_withdrawal * is_full_withdrawal + + # If is_full_withdrawal is set, balance_after must be 0. + assert is_full_withdrawal * balance_after = 0 + + # balance_before and balance_after were range checked and are guaranteed to be in the range + # [0, BALANCE_BOUND) => diff is in the range (-BALANCE_BOUND, BALANCE_BOUND) + # => biased_diff is in the range [1, 2*BALANCE_BOUND). + tempvar diff = balance_after - balance_before + tempvar biased_diff = diff + BALANCE_BOUND + assert output.action = ((is_full_withdrawal * VAULT_SHIFT) + vault_index) * BALANCE_SHIFT + + biased_diff + + vault_update_balances( + balance_before=balance_before, + balance_after=balance_after, + stark_key=output.stark_key, + token_id=output.token_id, + vault_index=vault_index, + vault_change_ptr=vault_dict) + + return ( + range_check_ptr=range_check_ptr + 6, + modification_ptr=modification_ptr + ModificationOutput.SIZE, + vault_dict=vault_dict + DictAccess.SIZE) +end diff --git a/src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo b/src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo new file mode 100644 index 00000000..a42c471d --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo @@ -0,0 +1,74 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_context import DexContext +from starkware.cairo.apps.starkex2_0.execute_limit_order import execute_limit_order + +# Executes a settlement between two parties, where each party signed an appropriate limit order +# and those orders match. +func execute_settlement( + hash_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + vault_dict : DictAccess*, order_dict : DictAccess*, dex_context_ptr : DexContext*) -> ( + hash_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + vault_dict : DictAccess*, order_dict : DictAccess*): + # Local variables. + alloc_locals + local party_a_sold + local party_b_sold + local token_a # Token sold by party a, and bought by party b. + local token_b # Token sold by party b, and bought by party a. + + # Define an inclusive amount bound reference for amount range-checks. + tempvar inclusive_amount_bound = BALANCE_BOUND - 1 + + # Check that 0 <= party_a_sold < BALANCE_BOUND. + assert [range_check_ptr] = party_a_sold + # Guarantee that party_a_sold <= inclusive_amount_bound < BALANCE_BOUND. + assert [range_check_ptr + 1] = inclusive_amount_bound - party_a_sold + + # Check that 0 <= party_b_sold < BALANCE_BOUND. + assert [range_check_ptr + 2] = party_b_sold + # Guarantee that party_b_sold <= inclusive_amount_bound < BALANCE_BOUND. + assert [range_check_ptr + 3] = inclusive_amount_bound - party_b_sold + + # Call execute_limit_order for party a: + let return0 = execute_limit_order( + hash_ptr=hash_ptr, + range_check_ptr=range_check_ptr + 4, + ecdsa_ptr=ecdsa_ptr, + sell_change=vault_dict, + buy_change=vault_dict + 3 * DictAccess.SIZE, + order_dict=order_dict, + amount_sold=party_a_sold, + amount_bought=party_b_sold, + token_sell=token_a, + token_buy=token_b, + dex_context_ptr=dex_context_ptr) + + # Call execute_limit_order for party b. + let return1 = execute_limit_order( + hash_ptr=return0.hash_ptr, + range_check_ptr=return0.range_check_ptr, + ecdsa_ptr=return0.ecdsa_ptr, + sell_change=vault_dict + 2 * DictAccess.SIZE, + buy_change=vault_dict + 1 * DictAccess.SIZE, + order_dict=return0.order_dict, + amount_sold=party_b_sold, + amount_bought=party_a_sold, + token_sell=token_b, + token_buy=token_a, + dex_context_ptr=dex_context_ptr) + + return ( + hash_ptr=return1.hash_ptr, + range_check_ptr=return1.range_check_ptr, + ecdsa_ptr=return1.ecdsa_ptr, + vault_dict=vault_dict + 4 * DictAccess.SIZE, + order_dict=return1.order_dict) +end diff --git a/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo b/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo new file mode 100644 index 00000000..ff42e4d7 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo @@ -0,0 +1,137 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.common.merkle_update import merkle_update +from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import EXPIRATION_TIMESTAMP_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import NONCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import PackedOrderMsg +from starkware.cairo.apps.starkex2_0.dex_context import DexContext +from starkware.cairo.apps.starkex2_0.vault_update import vault_update_diff +from starkware.cairo.apps.starkex2_0.verify_order_signature import verify_order_signature + +# Executes a (conditional) transfer order. +# A (conditional) transfer order can be described by the following statement: +# "I want to transfer exactly 'amount' tokens of type 'token' to user 'receiver_stark_key' +# in vault 'target_vault' (only if the specified 'condition' is satisfied)". +# +# Assumptions: +# * 0 <= global_expiration_timestamp, and it has not expired yet. +func execute_transfer( + hash_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + conditional_transfer_ptr, vault_dict : DictAccess*, order_dict : DictAccess*, + dex_context_ptr : DexContext*) -> ( + hash_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + conditional_transfer_ptr, vault_dict : DictAccess*, order_dict : DictAccess*): + # Local variables. + alloc_locals + local amount + local token_id + local sender_vault_id + local receiver_vault_id + local sender_stark_key + local receiver_stark_key + local nonce + local order_id + local expiration_timestamp + local order_type + local condition + local new_conditional_transfer_pointer + + let dex_context : DexContext* = dex_context_ptr + + # Check that 0 <= amount < BALANCE_BOUND. + tempvar inclusive_amount_bound = BALANCE_BOUND - 1 + assert [range_check_ptr] = amount + # Guarantee that amount <= inclusive_amount_bound < BALANCE_BOUND. + assert [range_check_ptr + 1] = inclusive_amount_bound - amount + + # Check that 0 <= nonce < NONCE_BOUND. + tempvar inclusive_nonce_bound = NONCE_BOUND - 1 + assert [range_check_ptr + 2] = nonce + # Guarantee that nonce <= inclusive_nonce_bound < NONCE_BOUND. + assert [range_check_ptr + 3] = inclusive_nonce_bound - nonce + + # Check that the order has not expired yet. + tempvar global_expiration_timestamp = dex_context.global_expiration_timestamp + # Guarantee that global_expiration_timestamp <= expiration_timestamp, which also implies that + # 0 <= expiration_timestamp. + assert [range_check_ptr + 4] = expiration_timestamp - global_expiration_timestamp + + # Check that expiration_timestamp < EXPIRATION_TIMESTAMP_BOUND. + tempvar inclusive_expiration_timestamp_bound = EXPIRATION_TIMESTAMP_BOUND - 1 + # Guarantee that expiration_timestamp <= inclusive_expiration_timestamp_bound < + # EXPIRATION_TIMESTAMP_BOUND. + assert [range_check_ptr + 5] = inclusive_expiration_timestamp_bound - expiration_timestamp + + # Call vault_update for the sender. + let sender_vault_update_ret = vault_update_diff( + range_check_ptr=range_check_ptr + 6, + diff=amount * (-1), + stark_key=sender_stark_key, + token_id=token_id, + vault_index=sender_vault_id, + vault_change_ptr=vault_dict) + + # Call vault_update for the receiver. + let receiver_vault_update_ret = vault_update_diff( + range_check_ptr=sender_vault_update_ret.range_check_ptr, + diff=amount, + stark_key=receiver_stark_key, + token_id=token_id, + vault_index=receiver_vault_id, + vault_change_ptr=vault_dict + DictAccess.SIZE) + + local range_check_ptr_after_vault_update = receiver_vault_update_ret.range_check_ptr + + # Assert that the correct order_type is given for transfer (condition == 0) and + # conditional transfer (condition != 0). + + if condition != 0: + # Conditional transfer. + order_type = PackedOrderMsg.CONDITIONAL_TRANSFER_ORDER_TYPE + [conditional_transfer_ptr] = condition + new_conditional_transfer_pointer = conditional_transfer_ptr + 1 + else: + # Normal transfer. + order_type = PackedOrderMsg.TRANSFER_ORDER_TYPE + new_conditional_transfer_pointer = conditional_transfer_ptr + end + + let verify_order_signature_ret = verify_order_signature( + hash_ptr=hash_ptr, + range_check_ptr=range_check_ptr_after_vault_update, + ecdsa_ptr=ecdsa_ptr, + public_key=sender_stark_key, + order_type=order_type, + vault0=sender_vault_id, + vault1=receiver_vault_id, + amount0=amount, + amount1=0, + token0=token_id, + token1_or_pub_key=receiver_stark_key, + nonce=nonce, + expiration_timestamp=expiration_timestamp, + order_id=order_id, + condition=condition) + + # Update orders dict. + let order_dict_access : DictAccess* = order_dict + order_id = order_dict_access.key + tempvar zero = 0 + zero = order_dict_access.prev_value + amount = order_dict_access.new_value + + return ( + hash_ptr=verify_order_signature_ret.hash_ptr, + range_check_ptr=verify_order_signature_ret.range_check_ptr, + ecdsa_ptr=verify_order_signature_ret.ecdsa_ptr, + conditional_transfer_ptr=new_conditional_transfer_pointer, + vault_dict=vault_dict + 2 * DictAccess.SIZE, + order_dict=order_dict + DictAccess.SIZE) +end diff --git a/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo b/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo new file mode 100644 index 00000000..0a7f38e9 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo @@ -0,0 +1,60 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.vault_update import compute_vault_hash +from starkware.cairo.apps.starkex2_0.vault_update import VaultState + +# Gets a single pointer to a vault state and outputs the hash of that vault. +func hash_vault_state_ptr(hash_ptr : HashBuiltin*, vault_state_ptr : VaultState*) -> ( + vault_hash, hash_ptr : HashBuiltin*): + let hash_builtin : HashBuiltin* = hash_ptr + let vault_state : VaultState* = vault_state_ptr + + assert hash_builtin.x = vault_state.stark_key + assert hash_builtin.y = vault_state.token_id + + # Compute new hash. + compute_vault_hash( + hash_ptr=hash_ptr + HashBuiltin.SIZE, + key_token_hash=hash_builtin.result, + amount=vault_state.balance) + return (...) +end + +# Takes a vault_ptr_dict with pointers to vault states and writes a new vault_hash_dict with +# hashed vaults instead of pointers. +# The size of the vault_hash_dict is the same as the original dict and the DictAccess keys are +# copied as is. +func hash_vault_ptr_dict( + hash_ptr : HashBuiltin*, vault_ptr_dict : DictAccess*, n_entries, + vault_hash_dict : DictAccess*) -> (hash_ptr : HashBuiltin*): + if n_entries == 0: + return (hash_ptr=hash_ptr) + end + + let hash_builtin : HashBuiltin* = hash_ptr + let vault_access : DictAccess* = vault_ptr_dict + let hashed_vault_access : DictAccess* = vault_hash_dict + + # Copy the key. + assert hashed_vault_access.key = vault_access.key + let prev_hash_res = hash_vault_state_ptr( + hash_ptr=hash_ptr, vault_state_ptr=cast(vault_access.prev_value, VaultState*)) + hashed_vault_access.prev_value = prev_hash_res.vault_hash + + let new_hash_res = hash_vault_state_ptr( + hash_ptr=prev_hash_res.hash_ptr, vault_state_ptr=cast(vault_access.new_value, VaultState*)) + hashed_vault_access.new_value = new_hash_res.vault_hash + + # Tail call. + hash_vault_ptr_dict( + hash_ptr=new_hash_res.hash_ptr, + vault_ptr_dict=vault_ptr_dict + DictAccess.SIZE, + n_entries=n_entries - 1, + vault_hash_dict=vault_hash_dict + DictAccess.SIZE) + return (...) +end diff --git a/src/starkware/cairo/apps/starkex2_0/main.cairo b/src/starkware/cairo/apps/starkex2_0/main.cairo new file mode 100644 index 00000000..8d35f077 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/main.cairo @@ -0,0 +1,137 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +%builtins output pedersen range_check ecdsa + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.common.dict import squash_dict +from starkware.cairo.apps.starkex2_0.common.merkle_multi_update import merkle_multi_update +from starkware.cairo.apps.starkex2_0.dex_context import make_dex_context +from starkware.cairo.apps.starkex2_0.execute_batch import execute_batch +from starkware.cairo.apps.starkex2_0.execute_modification import ModificationOutput +from starkware.cairo.apps.starkex2_0.hash_vault_ptr_dict import hash_vault_ptr_dict +from starkware.cairo.apps.starkex2_0.vault_update import VaultState + +struct DexOutput: + member initial_vault_root = 0 + member final_vault_root = 1 + member initial_order_root = 2 + member final_order_root = 3 + member global_expiration_timestamp = 4 + member vault_tree_height = 5 + member order_tree_height = 6 + member n_modifications = 7 + member n_conditional_transfers = 8 + const SIZE = 9 +end + +func main( + output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, + ecdsa_ptr : SignatureBuiltin*) -> ( + output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, + ecdsa_ptr : SignatureBuiltin*): + alloc_locals + # Create the globals struct. + let dex_output = cast(output_ptr, DexOutput*) + let (dex_context_ptr) = make_dex_context( + vault_tree_height=dex_output.vault_tree_height, + order_tree_height=dex_output.order_tree_height, + global_expiration_timestamp=dex_output.global_expiration_timestamp) + + local vault_dict : DictAccess* + local order_dict : DictAccess* + local conditional_transfer_ptr + # Call execute_batch. + # Advance output_ptr by DexOutput.SIZE, since DexOutput appears before other stuff. + let executed_batch = execute_batch( + modification_ptr=cast(output_ptr + DexOutput.SIZE, ModificationOutput*), + conditional_transfer_ptr=conditional_transfer_ptr, + hash_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + vault_dict=vault_dict, + order_dict=order_dict, + dex_context_ptr=dex_context_ptr) + + # Assert conditional transfer data starts where modification data ends. + conditional_transfer_ptr = executed_batch.modification_ptr + + # Store conditional transfer end pointer. + local conditional_transfer_end_ptr : felt* = executed_batch.conditional_transfer_ptr + + # Assert that the number of modifications that appear in the output is correct. + assert dex_output.n_modifications = ( + cast(conditional_transfer_ptr, felt) - (cast(output_ptr, felt) + DexOutput.SIZE)) / + ModificationOutput.SIZE + + # Assert that the number of conditional transfers that appear in the output is correct. + assert dex_output.n_conditional_transfers = ( + conditional_transfer_end_ptr - conditional_transfer_ptr) + + # Store builtin pointers. + local hash_ptr_after_execute_batch : HashBuiltin* = executed_batch.hash_ptr + local ecdsa_ptr_after_execute_batch : SignatureBuiltin* = executed_batch.ecdsa_ptr + local order_dict_end : DictAccess* = executed_batch.order_dict + + # Check that the vault and order accesses recorded in vault_dict and dict_vault are + # valid lists of dict accesses and squash them to obtain squashed dicts + # (squashed_vault_dict and squashed_order_dict) with one entry per key + # (value before and value after) which summarizes all the accesses to that key. + + # Squash the vault_dict. + local squashed_vault_dict : DictAccess* + let (range_check_ptr, squash_vault_dict_ret) = squash_dict( + range_check_ptr=executed_batch.range_check_ptr, + dict_accesses=vault_dict, + dict_accesses_end=executed_batch.vault_dict, + squashed_dict=squashed_vault_dict) + local squashed_vault_dict_segment_size = squash_vault_dict_ret - squashed_vault_dict + + # Squash the order_dict. + local squashed_order_dict : DictAccess* + let (range_check_ptr, squash_order_dict_ret) = squash_dict( + range_check_ptr=range_check_ptr, + dict_accesses=order_dict, + dict_accesses_end=order_dict_end, + squashed_dict=squashed_order_dict) + local squashed_order_dict_segment_size = squash_order_dict_ret - squashed_order_dict + local range_check_ptr_after_squash_order_dict = range_check_ptr + + # The squashed_vault_dict holds pointers to vault states instead of vault tree leaf values. + # Call hash_vault_ptr_dict to obtain a new dict that can be passed to merkle_multi_update. + local hashed_vault_dict : DictAccess* + let (hash_vault_dict_ptr) = hash_vault_ptr_dict( + hash_ptr=hash_ptr_after_execute_batch, + vault_ptr_dict=squashed_vault_dict, + n_entries=squashed_vault_dict_segment_size / DictAccess.SIZE, + vault_hash_dict=hashed_vault_dict) + + # Verify hashed_vault_dict consistency with the vault merkle root. + let (vault_merkle_multi_update_ptr) = merkle_multi_update( + hash_ptr=hash_vault_dict_ptr, + update_ptr=hashed_vault_dict, + n_updates=squashed_vault_dict_segment_size / DictAccess.SIZE, + height=dex_output.vault_tree_height, + prev_root=dex_output.initial_vault_root, + new_root=dex_output.final_vault_root) + + # Verify squashed_order_dict consistency with the order merkle root. + let (order_merkle_multi_update_ptr) = merkle_multi_update( + hash_ptr=vault_merkle_multi_update_ptr, + update_ptr=squashed_order_dict, + n_updates=squashed_order_dict_segment_size / DictAccess.SIZE, + height=dex_output.order_tree_height, + prev_root=dex_output.initial_order_root, + new_root=dex_output.final_order_root) + + # Return updated pointers. + return ( + output_ptr=conditional_transfer_end_ptr, + pedersen_ptr=order_merkle_multi_update_ptr, + range_check_ptr=range_check_ptr_after_squash_order_dict, + ecdsa_ptr=ecdsa_ptr_after_execute_batch) +end diff --git a/src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py b/src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py new file mode 100644 index 00000000..a6ed0c48 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/starkex2_0_program_test.py @@ -0,0 +1,40 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +import os +import subprocess +import tempfile + +from starkware.python.utils import get_build_dir_path + + +def test_program_hash(): + """ + Tests that the hash of the compiled Cairo program is identical to the one used in the + StarkEx2.0 system. + """ + DIR = os.path.dirname(__file__) + CAIRO_PATH = os.path.join(DIR, '../../../..') + PROGRAM_MAIN_FILE = os.path.join(DIR, 'main.cairo') + CAIRO_COMPILE_EXE = get_build_dir_path('src/starkware/cairo/lang/compiler/cairo_compile_exe') + CAIRO_HASH_PROGRAM_EXE = get_build_dir_path( + 'src/starkware/cairo/bootloader/cairo_hash_program_exe') + + with tempfile.NamedTemporaryFile() as compiled_program: + # Compile the program. + subprocess.check_call([ + f'{CAIRO_COMPILE_EXE}', + PROGRAM_MAIN_FILE, + f'--output={compiled_program.name}', + f'--cairo_path={CAIRO_PATH}', + ]) + program_hash = subprocess.check_output([ + f'{CAIRO_HASH_PROGRAM_EXE}', + f'--program={compiled_program.name}', + ]).decode('ascii').strip() + + # NOTE: The following is the hash of the deployed program in the StarkEx2.0 system. + # It should not be modified. + assert program_hash == '0x15bd9af059b37335cf934461ce167400eec0ef18605193a25fc4bc6f661984a' diff --git a/src/starkware/cairo/apps/starkex2_0/vault_update.cairo b/src/starkware/cairo/apps/starkex2_0/vault_update.cairo new file mode 100644 index 00000000..63332f23 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/vault_update.cairo @@ -0,0 +1,98 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess +from starkware.cairo.apps.starkex2_0.common.merkle_update import merkle_update +from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import ZERO_VAULT_HASH + +struct VaultState: + member stark_key = 0 + member token_id = 1 + member balance = 2 + const SIZE = 3 +end + +# Retrieves a pointer to a VaultState with the corresponding vault. +# Returns an empty vault if balance == 0 (stark_key and token_id are ignored). +func get_vault_state(stark_key, token_id, balance) -> (vault_state_ptr : VaultState*): + local vault_state_ptr : VaultState* + + # Allocate 1 slot for our local which is also the return value. + vault_state_ptr.balance = balance; ap++ + static_assert SIZEOF_LOCALS == 1 + + if balance == 0: + # Balance is 0 here, use it for initialization. + let zero = balance + vault_state_ptr.stark_key = zero + vault_state_ptr.token_id = zero + return (...) + end + + vault_state_ptr.stark_key = stark_key + vault_state_ptr.token_id = token_id + return (...) +end + +# Computes the hash h(key_token_hash, amount), where key_token_hash := h(stark_key, token_id). +func compute_vault_hash(hash_ptr : HashBuiltin*, key_token_hash, amount) -> ( + vault_hash, hash_ptr : HashBuiltin*): + if amount == 0: + return (vault_hash=ZERO_VAULT_HASH, hash_ptr=hash_ptr) + end + + key_token_hash = hash_ptr.x + amount = hash_ptr.y + return (vault_hash=hash_ptr.result, hash_ptr=hash_ptr + HashBuiltin.SIZE) +end + +# Updates the balance in the vault (leaf in the vault tree) corresponding to vault_index, +# by writing the change to vault_change_ptr. +# May also by used to verify the values in a certain vault. +func vault_update_balances( + balance_before, balance_after, stark_key, token_id, vault_index, + vault_change_ptr : DictAccess*): + let vault_access : DictAccess* = vault_change_ptr + vault_access.key = vault_index + let (prev_vault_state_ptr) = get_vault_state( + stark_key=stark_key, token_id=token_id, balance=balance_before) + vault_access.prev_value = prev_vault_state_ptr + let (new_vault_state_ptr) = get_vault_state( + stark_key=stark_key, token_id=token_id, balance=balance_after) + vault_access.new_value = new_vault_state_ptr + return () +end + +# Similar to vault_update_balances, except that the expected difference +# (balance_after - balance_before) is given and a range-check is performed on balance_after. +func vault_update_diff( + range_check_ptr, diff, stark_key, token_id, vault_index, + vault_change_ptr : DictAccess*) -> (range_check_ptr): + # Local variables. + alloc_locals + local balance_before + local balance_after + + balance_after = balance_before + diff + + # Check that 0 <= balance_after < BALANCE_BOUND. + assert [range_check_ptr] = balance_after + # Apply the range check builtin on (BALANCE_BOUND - 1 - balance_after), which guarantees that + # balance_after < BALANCE_BOUND. + assert [range_check_ptr + 1] = (BALANCE_BOUND - 1) - balance_after + + # Call vault_update_balances. + vault_update_balances( + balance_before=balance_before, + balance_after=balance_after, + stark_key=stark_key, + token_id=token_id, + vault_index=vault_index, + vault_change_ptr=vault_change_ptr) + + return (range_check_ptr=range_check_ptr + 2) +end diff --git a/src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo b/src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo new file mode 100644 index 00000000..b1fc8b12 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo @@ -0,0 +1,50 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.dex_constants import HASH_MESSAGE_BOUND as DEX_HASH_MESSAGE_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import ORDER_ID_BOUND as DEX_ORDER_ID_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import RANGE_CHECK_BOUND as DEX_RANGE_CHECK_BOUND + +# Verifies that the given order_id complies with the order data, encoded in the message_hash. +# The order_id is represented by the 63 most significant bits of the message_hash. +# +# Assumptions: +# * 0 <= order_id < ORDER_ID_BOUND. +func verify_order_id(range_check_ptr, message_hash, order_id) -> (range_check_ptr): + # Copy constants to allow overriding them in the tests. + const HASH_MESSAGE_BOUND = DEX_HASH_MESSAGE_BOUND + const ORDER_ID_BOUND = DEX_ORDER_ID_BOUND + const RANGE_CHECK_BOUND = DEX_RANGE_CHECK_BOUND + + # The 251-bit message_hash can be viewed as a packing of three fields: + # +----------------+--------------------+----------------LSB-+ + # | order_id (63b) | middle_field (60b) | right_field (128b) | + # +----------------+--------------------+--------------------+ + # . + const ORDER_ID_SHIFT = HASH_MESSAGE_BOUND / ORDER_ID_BOUND + const MIDDLE_FIELD_BOUND = ORDER_ID_SHIFT / RANGE_CHECK_BOUND + + # Local variables. + alloc_locals + local middle_field + local right_field + + # Verify that the message_hash definition holds, i.e., that: + # message_hash = ORDER_ID_SHIFT * order_id + RANGE_CHECK_BOUND * middle_field + right_field. + tempvar shifted_middle_field = middle_field * RANGE_CHECK_BOUND + tempvar packed_right_fields = shifted_middle_field + right_field + tempvar shifted_order_id = order_id * ORDER_ID_SHIFT + message_hash = shifted_order_id + packed_right_fields + + # Verify the message_hash structure (i.e., the size of each field), to ensure unique unpacking. + # Note that the size of order_id is verified by performing merkle_update on the order tree. + # Check that 0 <= right_field < RANGE_CHECK_BOUND. + assert [range_check_ptr] = right_field + # Check that 0 <= middle_field < MIDDLE_FIELD_BOUND. + assert [range_check_ptr + 1] = middle_field + assert [range_check_ptr + 2] = (MIDDLE_FIELD_BOUND - 1) - middle_field + + return (range_check_ptr=range_check_ptr + 3) +end diff --git a/src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo b/src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo new file mode 100644 index 00000000..522a1cd5 --- /dev/null +++ b/src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo @@ -0,0 +1,108 @@ +# *********************************************************************** +# * This code is licensed under the Cairo Program License. * +# * The license can be found in: licenses/CairoProgramLicense.txt * +# *********************************************************************** + +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.dex_constants import PackedOrderMsg +from starkware.cairo.apps.starkex2_0.verify_order_id import verify_order_id + +# Computes partial_msg_hash for the signed message of transfer and conditional transfer. +# +# If the order is a transfer (condition == 0), returns temp_partial_msg_hash. If the order is a +# conditional transfer (condition != 0), returns hash(temp_partial_msg_hash, condition). +# The returnd value is the first argument to the hash function that computes the signed message, +# i.e. - partial_msg_hash in hash(partial_msg_hash, packed_msg). +# See the documentation of verify_order_signature for more details. +func add_optional_condition_hash(temp_partial_msg_hash, condition, hash_ptr : HashBuiltin*) -> ( + partial_msg_hash, hash_ptr : HashBuiltin*): + if condition == 0: + return (partial_msg_hash=temp_partial_msg_hash, hash_ptr=hash_ptr) + end + + let partial_msg_hash : HashBuiltin* = hash_ptr + partial_msg_hash.x = temp_partial_msg_hash + partial_msg_hash.y = condition + return (partial_msg_hash=partial_msg_hash.result, hash_ptr=hash_ptr + HashBuiltin.SIZE) +end + +# Verifies that the order was signed and that the order_id complies with the order data. +# +# The format of the signed message is as follows: hash(partial_msg_hash, packed_msg). +# packed_msg is a packed field which consists of the part of the order data that is not included in +# the partial_msg_hash, with the following structure: +# +-MSB-------------+--------------+--------------+---------------+---------------+ +# | order_type (4b) | vault0 (31b) | vault1 (31b) | amount0 (63b) | amount1 (63b) | .... +# +-----------------+--------------+--------------+---------------+---------------+ +# +# +-------------+------------------------LSB-+ +# | nonce (31b) | expiration_timestamp (22b) | +# +-------------+----------------------------+ +# +# In case of a settlement (order_type = 0): +# partial_msg_hash := hash(token0, token1), and parameter names with '0' suffix represent sell data, +# while the ones with '1' suffix represent buy data. +# +# In case of a transfer (order_type = 1): +# partial_msg_hash := hash(token0, receiver_public_key), and parameter names with '0' suffix +# represent sender data, while the ones with '1' suffix represent receiver data (and amount1 = 0). +# +# In case of a conditional transfer (order_type = 2): +# partial_msg_hash := hash(hash(token0, receiver_public_key), condition). +# +# Assumptions: +# * order_type = 0, 1 or 2. +# * 0 <= vault0, vault1 < VAULT_SHIFT. +# * 0 <= amount0, amount1 < AMOUNT_SHIFT. +# * 0 <= nonce < NONCE_SHIFT. +# * 0 <= expiration_timestamp < EXPIRATION_TIMESTAMP_SHIFT. +func verify_order_signature( + hash_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, public_key, + order_type, vault0, vault1, amount0, amount1, token0, token1_or_pub_key, nonce, + expiration_timestamp, order_id, condition) -> ( + range_check_ptr, hash_ptr : HashBuiltin*, ecdsa_ptr : SignatureBuiltin*): + alloc_locals + local packed_msg + + # Compute packed order message. + assert packed_msg = ((((((order_type * + PackedOrderMsg.VAULT_SHIFT + vault0) * + PackedOrderMsg.VAULT_SHIFT + vault1) * + PackedOrderMsg.AMOUNT_SHIFT + amount0) * + PackedOrderMsg.AMOUNT_SHIFT + amount1) * + PackedOrderMsg.NONCE_SHIFT + nonce) * + PackedOrderMsg.EXPIRATION_TIMESTAMP_SHIFT + expiration_timestamp) + + # Compute partial_msg_hash. + let temp_partial_msg_hash : HashBuiltin* = hash_ptr + temp_partial_msg_hash.x = token0 + temp_partial_msg_hash.y = token1_or_pub_key + + # Call add_optional_condition_hash. + let (partial_msg_hash, hash_ptr) = add_optional_condition_hash( + temp_partial_msg_hash=temp_partial_msg_hash.result, + condition=condition, + hash_ptr=hash_ptr + HashBuiltin.SIZE) + + # Compute the message to sign on. + local full_msg_hash : HashBuiltin* = hash_ptr + full_msg_hash.x = partial_msg_hash + full_msg_hash.y = packed_msg + + # Verify signature. + let signature : SignatureBuiltin* = ecdsa_ptr + signature.pub_key = public_key + tempvar full_msg_hash_result = full_msg_hash.result + signature.message = full_msg_hash_result + + # Call verify_order_id. + verify_order_id( + range_check_ptr=range_check_ptr, message_hash=full_msg_hash_result, order_id=order_id) + + # range_check_ptr is already in [ap - 1]. + return ( + ..., + hash_ptr=full_msg_hash + HashBuiltin.SIZE, + ecdsa_ptr=ecdsa_ptr + SignatureBuiltin.SIZE) +end diff --git a/src/starkware/cairo/bootloader/CMakeLists.txt b/src/starkware/cairo/bootloader/CMakeLists.txt new file mode 100644 index 00000000..2ab3a0e9 --- /dev/null +++ b/src/starkware/cairo/bootloader/CMakeLists.txt @@ -0,0 +1,35 @@ +python_lib(cairo_hash_program_lib + PREFIX starkware/cairo/bootloader + + FILES + hash_program.py + + LIBS + cairo_common_lib + cairo_compile_lib + cairo_vm_crypto_lib +) + +python_venv(cairo_hash_program_venv + PYTHON python3.7 + LIBS + cairo_hash_program_lib +) + +python_exe(cairo_hash_program_exe + VENV cairo_hash_program_venv + MODULE starkware.cairo.bootloader.hash_program +) + +full_python_test(cairo_hash_program_test + PREFIX starkware/cairo/bootloader + PYTHON python3.7 + TESTED_MODULES starkware/cairo/bootloader + + FILES + hash_program_test.py + + LIBS + cairo_hash_program_lib + pip_pytest +) diff --git a/src/starkware/cairo/bootloader/hash_program.py b/src/starkware/cairo/bootloader/hash_program.py new file mode 100644 index 00000000..357aedf1 --- /dev/null +++ b/src/starkware/cairo/bootloader/hash_program.py @@ -0,0 +1,38 @@ +import argparse +import json + +from starkware.cairo.common.hash_chain import compute_hash_chain +from starkware.cairo.lang.compiler.program import Program, ProgramBase +from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager + + +def compute_program_hash_chain(program: ProgramBase, bootloader_version=0): + """ + Computes a hash chain over a program, including the length of the data chain. + """ + builtin_list = [int.from_bytes(builtin.encode('ascii'), 'big') for builtin in program.builtins] + # The program header below is missing the data length, which is later added to the data_chain. + program_header = [bootloader_version, program.main, len(program.builtins)] + builtin_list + data_chain = program_header + program.data + + return compute_hash_chain([len(data_chain)] + data_chain) + + +def main(): + parser = argparse.ArgumentParser( + description='A tool to compute the hash of a cairo program') + parser.add_argument( + '--program', type=argparse.FileType('r'), required=True, + help='The name of the program json file.') + parser.add_argument( + '--flavor', type=str, default='Release', choices=['Debug', 'Release', 'RelWithDebInfo'], + help='Build flavor') + args = parser.parse_args() + + with get_crypto_lib_context_manager(args.flavor): + program = Program.Schema().load(json.load(args.program)) + print(hex(compute_program_hash_chain(program))) + + +if __name__ == '__main__': + main() diff --git a/src/starkware/cairo/bootloader/hash_program_test.py b/src/starkware/cairo/bootloader/hash_program_test.py new file mode 100644 index 00000000..7201fd1c --- /dev/null +++ b/src/starkware/cairo/bootloader/hash_program_test.py @@ -0,0 +1,8 @@ +from starkware.cairo.common.hash_chain import compute_hash_chain +from starkware.cairo.lang.vm.crypto import pedersen_hash + + +def test_compute_hash_chain(): + data = [1, 2, 3] + res = compute_hash_chain(data) + assert res == pedersen_hash(1, pedersen_hash(2, 3)) diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt new file mode 100644 index 00000000..20bb485e --- /dev/null +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -0,0 +1,25 @@ +python_lib(cairo_common_lib + PREFIX starkware/cairo/common + FILES + alloc.cairo + cairo_builtins.cairo + dict.cairo + dict.py + find_element.cairo + hash.cairo + hash_chain.cairo + hash_chain.py + hash_state.cairo + math.cairo + math_utils.py + memcpy.cairo + merkle_multi_update.cairo + merkle_update.cairo + registers.cairo + serialize.cairo + signature.cairo + ${CAIRO_COMMON_LIB_ADDITIONAL_FILES} + + LIBS + ${CAIRO_COMMON_LIB_ADDITIONAL_LIBS} +) diff --git a/src/starkware/cairo/common/alloc.cairo b/src/starkware/cairo/common/alloc.cairo new file mode 100644 index 00000000..b6e85f63 --- /dev/null +++ b/src/starkware/cairo/common/alloc.cairo @@ -0,0 +1,6 @@ +# Allocates a new memory segment. +func alloc() -> (ptr): + %{ memory[ap] = segments.add() %} + ap += 1 + return (...) +end diff --git a/src/starkware/cairo/common/cairo_builtins.cairo b/src/starkware/cairo/common/cairo_builtins.cairo new file mode 100644 index 00000000..9113b6ee --- /dev/null +++ b/src/starkware/cairo/common/cairo_builtins.cairo @@ -0,0 +1,22 @@ +# A representation of a HashBuiltin struct, specifying the hash builtin memory structure. +struct HashBuiltin: + member x = 0 + member y = 1 + member result = 2 + const SIZE = 3 +end + +# A representation of a SignatureBuiltin struct, specifying the signature builtin memory structure. +struct SignatureBuiltin: + member pub_key = 0 + member message = 1 + const SIZE = 2 +end + +# A representation of a CheckpointsBuiltin struct, specifying the checkpoints builtin memory +# structure. +struct CheckpointsBuiltin: + member required_pc = 0 + member required_fp = 1 + const SIZE = 2 +end diff --git a/src/starkware/cairo/common/dict.cairo b/src/starkware/cairo/common/dict.cairo new file mode 100644 index 00000000..78349bf8 --- /dev/null +++ b/src/starkware/cairo/common/dict.cairo @@ -0,0 +1,258 @@ +struct DictAccess: + member key = 0 + member prev_value = 1 + member new_value = 2 + const SIZE = 3 +end + +# Inner tail-recursive function for squash_dict. +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. +# dict_accesses_end_minus1 - a pointer to the end of said array, minus 1. +# min_key - minimum allowed key. Used to enforce monotonicity of keys. +# remaining_accesses - remaining number of accesses that need to be accounted for. Starts with +# the total number of entries in dict_accesses array, and slowly decreases until it reaches 0. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Hints: +# keys - a descending list of the keys for which we have accesses. Destroyed in the process. +# access_indices - A map from key to a descending list of indices in the dict_accesses array that +# access this key. Destroyed in the process. +# +# Returns: +# range_check_ptr - updated range check builtin pointer. +# squashed_dict - end pointer to squashed_dict. +func squash_dict_inner( + range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end_minus1 : felt*, min_key, + remaining_accesses, squashed_dict : DictAccess*) -> ( + range_check_ptr, squashed_dict : DictAccess*): + # Exit recursion when done. + if remaining_accesses == 0: + %{ assert len(keys) == 0 %} + return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) + end + + # Locals. + struct Locals: + member key = 0 + member should_skip_loop = 1 + member first_value = 2 + const SIZE = 3 + end + let locals = cast(fp, Locals*) + let key = locals.key + let dict_diff : DictAccess* = squashed_dict + ap += Locals.SIZE + + # Guess key and check that key >= min_key. + %{ ids.locals.key = key = keys.pop() %} + [ap] = key - min_key + [ap] = [range_check_ptr]; ap++ + + # Loop to verify chronological accesses to the key. + # These values are not needed from previous iteration. + struct LoopTemps: + member index_delta_minus1 = 0 + member index_delta = 1 + member ptr_delta = 2 + member should_continue = 3 + const SIZE = 4 + end + # These values are needed from previous iteration. + struct LoopLocals: + member value = 0 + member access_ptr : DictAccess* = 1 + member range_check_ptr = 2 + const SIZE = 3 + end + + # Prepare first iteration. + %{ + current_access_indices = sorted(access_indices[key])[::-1] + current_access_index = current_access_indices.pop() + memory[ids.range_check_ptr + 1] = current_access_index + %} + # Check that first access_index >= 0. + tempvar current_access_index = [range_check_ptr + 1] + tempvar ptr_delta = current_access_index * DictAccess.SIZE + + let first_loop_locals = cast(ap, LoopLocals*) + first_loop_locals.access_ptr = dict_accesses + ptr_delta; ap++ + let first_access : DictAccess* = first_loop_locals.access_ptr + first_loop_locals.value = first_access.new_value; ap++ + first_loop_locals.range_check_ptr = range_check_ptr + 2; ap++ + + # Verify first key. + key = first_access.key + + # Write key and first value to dict_diff. + key = dict_diff.key + # Use a local variable, instead of a tempvar, to avoid increasing ap. + locals.first_value = first_access.prev_value + locals.first_value = dict_diff.prev_value + + # Skip loop non-deterministically if necessary. + %{ memory[fp + ids.Locals.should_skip_loop] = 0 if current_access_indices else 1 %} + jmp skip_loop if [fp + Locals.should_skip_loop] != 0 + + loop: + let prev_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + let loop_temps = cast(ap, LoopTemps*) + let loop_locals = cast(ap + LoopTemps.SIZE, LoopLocals*) + + # Check access_index. + %{ + new_access_index = current_access_indices.pop() + ids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1 + current_access_index = new_access_index + %} + # Check that new access_index > prev access_index. + loop_temps.index_delta_minus1 = [prev_loop_locals.range_check_ptr]; ap++ + loop_temps.index_delta = loop_temps.index_delta_minus1 + 1; ap++ + loop_temps.ptr_delta = loop_temps.index_delta * DictAccess.SIZE; ap++ + loop_locals.access_ptr = prev_loop_locals.access_ptr + loop_temps.ptr_delta; ap++ + + # Check valid transition. + let access : DictAccess* = loop_locals.access_ptr + prev_loop_locals.value = access.prev_value + loop_locals.value = access.new_value; ap++ + + # Verify key. + key = access.key + + # Next range_check_ptr. + loop_locals.range_check_ptr = prev_loop_locals.range_check_ptr + 1; ap++ + + %{ ids.loop_temps.should_continue = 1 if current_access_indices else 0 %} + jmp loop if loop_temps.should_continue != 0; ap++ + + skip_loop: + let last_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + + # Check if address is out of bounds. + %{ assert len(current_access_indices) == 0 %} + [ap] = dict_accesses_end_minus1 - cast(last_loop_locals.access_ptr, felt) + [ap] = [last_loop_locals.range_check_ptr]; ap++ + tempvar range_check_diff = last_loop_locals.range_check_ptr - range_check_ptr + tempvar n_used_accesses = range_check_diff - 1 + %{ assert ids.n_used_accesses == len(access_indices[key]) %} + + # Write last value to dict_diff. + last_loop_locals.value = dict_diff.new_value + + # Call squashed_dict_inner recursively. + squash_dict_inner( + range_check_ptr=last_loop_locals.range_check_ptr + 1, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end_minus1, + min_key=key + 1, + remaining_accesses=remaining_accesses - n_used_accesses, + squashed_dict=squashed_dict + DictAccess.SIZE) + return (...) +end + +# Verifies that dict_accesses lists valid chronological accesses (and updates) +# to a mutable dictionary and outputs a squashed dict with one DictAccess instance per key +# (value before and value after) which summarizes all the changes to that key. +# +# All keys are assumed to be in the range of the range check builtin (usually 2**128). +# +# Example: +# Input: {(key1, 0, 2), (key1, 2, 7), (key2, 4, 1), (key1, 7, 5), (key2, 1, 2)} +# Output: {(key1, 0, 5), (key2, 4, 2)} +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. The format of each +# entry is a triplet (key, prev_value, new_value). +# dict_accesses_end - a pointer to the end of said array. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Returns: +# range_check_ptr - updated range check builtin pointer. +# squashed_dict - end pointer to squashed_dict. +func squash_dict( + range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end : DictAccess*, + squashed_dict : DictAccess*) -> (range_check_ptr, squashed_dict : DictAccess*): + let ptr_diff = [fp] + %{ vm_enter_scope() %} + ptr_diff = dict_accesses_end - dict_accesses; ap++ + + if ptr_diff == 0: + # Access array is empty, nothing to check. + %{ vm_exit_scope() %} + return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) + end + + tempvar n_accesses = ptr_diff / DictAccess.SIZE + %{ + assert ids.ptr_diff % ids.DictAccess.SIZE == 0, \ + 'Accesses array size must be divisible by DictAccess.SIZE' + # A map from key to the list of indices accessing it. + access_indices = {} + for i in range(ids.n_accesses): + key = memory[ids.dict_accesses.address_ + ids.DictAccess.SIZE * i] + access_indices.setdefault(key, []).append(i) + # Descending list of keys. + keys = sorted(access_indices.keys())[::-1] + %} + + # Call inner. + squash_dict_inner( + range_check_ptr=range_check_ptr, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end - 1, + min_key=0, + remaining_accesses=n_accesses, + squashed_dict=squashed_dict) + %{ vm_exit_scope() %} + return (...) +end + +# Initializes the dict manager. Should be called exactly once at the beginning of a program that +# uses dicts. +func initialize_dict_manager() -> (): + %{ + from starkware.cairo.common.dict import DictManager + assert 'dict_manager' not in globals(), \ + 'initialize_dict_manager() must be called exactly once.' + dict_manager = DictManager() + %} + return () +end + +# Creates a new dict. +# Note that a dict_manager must be passed in the hints. +# Allocate one using initialize_dict_manager(). +func dict_new() -> (res): + %{ memory[ap] = dict_manager.new_dict(segments, initial_dict) %} + ap += 1 + return (...) +end + +# Updates a value in a dict. prev_value must be specified. A standalone read with no write should be +# performed by writing the same value. +# It is possible to get prev_value from dict_manager using the hint: +# %{ ids.val = dict_manager.get_dict(ids.dict_ptr)[ids.key] %} +func dict_update(dict_ptr : DictAccess*, key : felt, prev_value : felt, new_value : felt) -> ( + dict_ptr : DictAccess*): + %{ + # Verify dict pointer and prev value. + dict_tracker = dict_manager.get_tracker(ids.dict_ptr) + current_value = dict_tracker.data[ids.key] + assert current_value == ids.prev_value, \ + f'Wrong previous value in dict. Got {ids.prev_value}, expected {current_value}.' + + # Update value. + dict_tracker.data[ids.key] = ids.new_value + dict_tracker.current_ptr += ids.DictAccess.SIZE + %} + dict_ptr.key = key + dict_ptr.prev_value = prev_value + dict_ptr.new_value = new_value + return (dict_ptr=dict_ptr + DictAccess.SIZE) +end diff --git a/src/starkware/cairo/common/dict.py b/src/starkware/cairo/common/dict.py new file mode 100644 index 00000000..280a85b8 --- /dev/null +++ b/src/starkware/cairo/common/dict.py @@ -0,0 +1,58 @@ +import dataclasses +from typing import Dict + +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.vm_consts import VmConstsReference + + +@dataclasses.dataclass +class DictTracker: + """ + Tracks the python dict associated with a Cairo dict. + """ + # Python dict. + data: dict + # Pointer to the first unused position in the dict segment. + current_ptr: RelocatableValue + + +class DictManager: + """ + Manages dictionaries in a Cairo program. + Uses the segment index to associate the corresponding python dict with the Cairo dict. + """ + + def __init__(self): + # Mapping from segment index to the corresponding DictTracker of the Cairo dict. + self.trackers: Dict[int, DictTracker] = {} + + def new_dict(self, segments, initial_dict): + """ + Creates a new Cairo dictionary. The values of initial_dict can be integers, tuples or + lists. See MemorySegments.gen_arg(). + """ + base = segments.add() + assert base.segment_index not in self.trackers + self.trackers[base.segment_index] = DictTracker( + data={ + key: segments.gen_arg(value) for key, value in initial_dict.items()}, + current_ptr=base, + ) + return base + + def get_tracker(self, dict_ptr): + """ + Gets a dict tracker given the dict_ptr. + """ + if isinstance(dict_ptr, VmConstsReference): + dict_ptr = dict_ptr.address_ + dict_tracker = self.trackers[dict_ptr.segment_index] + assert dict_tracker.current_ptr == dict_ptr, 'Wrong dict pointer supplied. ' \ + f'Got {dict_ptr}, expected {dict_tracker.current_ptr}.' + return dict_tracker + + def get_dict(self, dict_ptr) -> dict: + """ + Gets the python dict that corresponds to dict_ptr. + """ + return self.get_tracker(dict_ptr).data diff --git a/src/starkware/cairo/common/find_element.cairo b/src/starkware/cairo/common/find_element.cairo new file mode 100644 index 00000000..b61dbf13 --- /dev/null +++ b/src/starkware/cairo/common/find_element.cairo @@ -0,0 +1,99 @@ +from starkware.cairo.common.math import assert_nn_le +from starkware.cairo.common.math import assert_le + +# Finds an element in the array whose first field is key and returns a pointer +# to this element. +# Since cairo is non-deterministic this is an O(1) operation. +# Note however that if the array has multiple elements with said key the function may return any +# of those elements. +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# array_ptr - pointer to an array. +# elm_size - size of an element in the array. +# n_elms - number of element in the array. +# key - key to look for. +# +# Returns: +# range_check_ptr - new range check builtin pointer. +# elm_ptr - pointer to an element in the array satisfying [ptr] = key. +func find_element(range_check_ptr, array_ptr : felt*, elm_size, n_elms, key) -> ( + range_check_ptr, elm_ptr : felt*): + alloc_locals + local index + %{ + for i in range(ids.n_elms): + if memory[ids.array_ptr + ids.elm_size * i] == ids.key: + ids.index = i + break + else: + raise ValueError(f'Key {ids.key} not found.') + %} + + let (range_check_ptr) = assert_nn_le(range_check_ptr=range_check_ptr, a=index, b=n_elms - 1) + tempvar elm_ptr = array_ptr + elm_size * index + assert [elm_ptr] = key + return (range_check_ptr=range_check_ptr, elm_ptr=elm_ptr) +end + +# Given an array sorted by its first field, returns the pointer to the first element in the array +# whose first field is at least key. If no such item exists, returns a pointer to the end of the +# array. +# Prover assumption: all the keys (the first field in each item) are in [0, RANGE_CHECK_BOUND). +func search_sorted_lower(range_check_ptr, array_ptr : felt*, elm_size, n_elms, key) -> ( + range_check_ptr, elm_ptr : felt*): + alloc_locals + local index + %{ + for i in range(ids.n_elms): + if memory[ids.array_ptr + ids.elm_size * i] >= ids.key: + ids.index = i + break + else: + ids.index = ids.n_elms + %} + + let (range_check_ptr) = assert_nn_le(range_check_ptr=range_check_ptr, a=index, b=n_elms) + local elm_ptr : felt* = array_ptr + elm_size * index + + local range_check_ptr1 + if index != n_elms: + let (range_check_ptr) = assert_le(range_check_ptr=range_check_ptr, a=key, b=[elm_ptr]) + range_check_ptr1 = range_check_ptr + else: + range_check_ptr1 = range_check_ptr + end + + local range_check_ptr2 + if index != 0: + let (range_check_ptr) = assert_le( + range_check_ptr=range_check_ptr1, a=[elm_ptr - elm_size] + 1, b=key) + range_check_ptr2 = range_check_ptr + else: + range_check_ptr2 = range_check_ptr1 + end + + return (range_check_ptr=range_check_ptr2, elm_ptr=elm_ptr) +end + +# Given an array sorted by its first field, returns the pointer to the first element in the array +# whose first field is exactly key. If no such item exists, returns an undefined pointer, +# and success=0. +# Prover assumption: all the keys (the first field in each item) are in [0, RANGE_CHECK_BOUND). +func search_sorted(range_check_ptr, array_ptr : felt*, elm_size, n_elms, key) -> ( + range_check_ptr, elm_ptr : felt*, success): + let (range_check_ptr, elm_ptr) = search_sorted_lower( + range_check_ptr=range_check_ptr, + array_ptr=array_ptr, + elm_size=elm_size, + n_elms=n_elms, + key=key) + tempvar array_end = array_ptr + elm_size * n_elms + if elm_ptr == array_end: + return (range_check_ptr=range_check_ptr, elm_ptr=array_ptr, success=0) + end + if [elm_ptr] != key: + return (range_check_ptr=range_check_ptr, elm_ptr=array_ptr, success=0) + end + return (range_check_ptr=range_check_ptr, elm_ptr=elm_ptr, success=1) +end diff --git a/src/starkware/cairo/common/hash.cairo b/src/starkware/cairo/common/hash.cairo new file mode 100644 index 00000000..81bc5c39 --- /dev/null +++ b/src/starkware/cairo/common/hash.cairo @@ -0,0 +1,16 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin + +# Computes the Pedersen hash of two given field elements. +# +# Arguments: +# pedersen_ptr - the Pedersen hash builtin pointer. +# x, y - the two field elements to be hashed, in this order. +# +# Returns: +# pedersen_ptr - the new Pedersen builtin pointer. +# result - the field element result of the hash. +func pedersen_hash(pedersen_ptr : HashBuiltin*, x, y) -> (pedersen_ptr : HashBuiltin*, result): + pedersen_ptr.x = x + pedersen_ptr.y = y + return (pedersen_ptr=pedersen_ptr + HashBuiltin.SIZE, result=pedersen_ptr.result) +end diff --git a/src/starkware/cairo/common/hash_chain.cairo b/src/starkware/cairo/common/hash_chain.cairo new file mode 100644 index 00000000..73e2ba99 --- /dev/null +++ b/src/starkware/cairo/common/hash_chain.cairo @@ -0,0 +1,51 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin + +# Computes a hash chain of a sequence whose length is given at [data_ptr] and the data starts at +# data_ptr + 1. The hash is calculated backwards (from the highest memory address to the lowest). +# For example, for the 3-element sequence [x, y, z] the hash is: +# h(3, h(x, h(y, z))) +# If data_length = 0, the function does not return (takes more than field prime steps). +func hash_chain(pedersen_ptr, data_ptr : felt*) -> (pedersen_ptr, hash): + struct LoopLocals: + member data_ptr : felt* = 0 + member pedersen_ptr : HashBuiltin* = 1 + member cur_hash = 2 + const SIZE = 3 + end + + let data_length = ap + [data_length] = [data_ptr]; ap++ + let loop_frame = cast(ap, LoopLocals*) + + # Prepare the loop_frame for the first iteration of the hash_loop. + loop_frame.data_ptr = data_ptr + [data_length]; ap++ + loop_frame.pedersen_ptr = pedersen_ptr; ap++ + loop_frame.cur_hash = [loop_frame.data_ptr]; ap++ + + hash_loop: + let curr_frame = cast(ap - LoopLocals.SIZE, LoopLocals*) + let current_hash : HashBuiltin* = curr_frame.pedersen_ptr + + let new_data_ptr = curr_frame.data_ptr - 1 + let new_data = ap + [new_data] = [new_data_ptr]; ap++ + + let n_elements_to_hash = ap + # Assign current_hash inputs and allocate space for n_elements_to_hash. + [new_data] = current_hash.x; ap++ + curr_frame.cur_hash = current_hash.y + + # Set the frame for the next loop iteration (going backwards). + let next_frame = cast(ap, LoopLocals*) + next_frame.data_ptr = new_data_ptr; ap++ + next_frame.pedersen_ptr = curr_frame.pedersen_ptr + HashBuiltin.SIZE; ap++ + next_frame.cur_hash = current_hash.result; ap++ + + # Update n_elements_to_hash and loop accordingly. Note that the hash is calculated backwards. + [n_elements_to_hash] = next_frame.data_ptr - data_ptr + jmp hash_loop if [n_elements_to_hash] != 0 + + # Note that the function return values (pedersen_ptr, hash) are at the end of next_frame so + # they are already located in the correct location. + ret +end diff --git a/src/starkware/cairo/common/hash_chain.py b/src/starkware/cairo/common/hash_chain.py new file mode 100644 index 00000000..33a1a13c --- /dev/null +++ b/src/starkware/cairo/common/hash_chain.py @@ -0,0 +1,12 @@ +from functools import reduce + +from starkware.cairo.lang.vm.crypto import pedersen_hash + + +def compute_hash_chain(data, hash_func=pedersen_hash): + """ + Computes a hash chain over the data, in the following order: + h(data[0], h(data[1], h(..., h(data[n-2], data[n-1])))). + """ + + return reduce(lambda x, y: hash_func(y, x), data[::-1]) diff --git a/src/starkware/cairo/common/hash_state.cairo b/src/starkware/cairo/common/hash_state.cairo new file mode 100644 index 00000000..ad9d086f --- /dev/null +++ b/src/starkware/cairo/common/hash_state.cairo @@ -0,0 +1,77 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash import pedersen_hash +from starkware.cairo.common.registers import get_fp_and_pc + +# Stores the hash of a sequence of items. New items can be added to the hash state using hash_update +# and hash_update_single. The final hash of the entire sequence, including the sequence length, can +# be extracted using hash_finalize. +# For example, the hash of the sequence (x, y, z) is h(h(h(h(0, x), y), z), 3). +# In particular, the hash of zero items is h(0, 0). +struct HashState: + member current_hash = 0 + member n_words = 1 + const SIZE = 2 +end + +# Initializes a new HashState with no items. +func hash_init() -> (hash_state_ptr : HashState*): + alloc_locals + let (__fp__, _) = get_fp_and_pc() + local hash_state : HashState + hash_state.current_hash = 0 + hash_state.n_words = 0 + return (hash_state_ptr=&hash_state) +end + +# A helper function for 'hash_update', see its documentaion. +# Computes the hash of an array of items, not including its length. +func hash_update_inner(pedersen_ptr : HashBuiltin*, curr_ptr : felt*, data_length, hash) -> ( + pedersen_ptr : HashBuiltin*, hash): + if data_length == 0: + return (pedersen_ptr=pedersen_ptr, hash=hash) + end + + let (pedersen_ptr, hash) = pedersen_hash(pedersen_ptr=pedersen_ptr, x=hash, y=[curr_ptr]) + let (pedersen_ptr, hash) = hash_update_inner( + pedersen_ptr=pedersen_ptr, curr_ptr=curr_ptr + 1, data_length=data_length - 1, hash=hash) + return (...) +end + +# Adds each item in an array of items to the HashState. +# The array is represented by a pointer and a length. +func hash_update( + pedersen_ptr : HashBuiltin*, hash_state_ptr : HashState*, data_ptr : felt*, + data_length) -> (pedersen_ptr : HashBuiltin*, new_hash_state_ptr : HashState*): + alloc_locals + let (pedersen_ptr, hash) = hash_update_inner( + pedersen_ptr=pedersen_ptr, + curr_ptr=data_ptr, + data_length=data_length, + hash=hash_state_ptr.current_hash) + let (__fp__, _) = get_fp_and_pc() + local new_hash_state : HashState + new_hash_state.current_hash = hash + assert new_hash_state.n_words = hash_state_ptr.n_words + data_length + return (pedersen_ptr=pedersen_ptr, new_hash_state_ptr=&new_hash_state) +end + +# Adds a single item to the HashState. +func hash_update_single(pedersen_ptr : HashBuiltin*, hash_state_ptr : HashState*, item) -> ( + pedersen_ptr : HashBuiltin*, new_hash_state_ptr : HashState*): + alloc_locals + let (pedersen_ptr, hash) = pedersen_hash( + pedersen_ptr=pedersen_ptr, x=hash_state_ptr.current_hash, y=item) + let (__fp__, _) = get_fp_and_pc() + local new_hash_state : HashState + new_hash_state.current_hash = hash + assert new_hash_state.n_words = hash_state_ptr.n_words + 1 + return (pedersen_ptr=pedersen_ptr, new_hash_state_ptr=&new_hash_state) +end + +# Returns the hash result of the HashState. +func hash_finalize(pedersen_ptr : HashBuiltin*, hash_state_ptr : HashState*) -> ( + pedersen_ptr : HashBuiltin*, hash): + pedersen_hash( + pedersen_ptr=pedersen_ptr, x=hash_state_ptr.current_hash, y=hash_state_ptr.n_words) + return (...) +end diff --git a/src/starkware/cairo/common/math.cairo b/src/starkware/cairo/common/math.cairo new file mode 100644 index 00000000..32fe7cb9 --- /dev/null +++ b/src/starkware/cairo/common/math.cairo @@ -0,0 +1,262 @@ +# Inline functions with no locals. + +# Verifies that value != 0. The proof will fail otherwise. +func assert_not_zero(value): + %{ assert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.' %} + if value == 0: + # If value == 0, add an unsatisfiable requirement. + value = 1 + end + + return () +end + +# Verifies that a != b. The proof will fail otherwise. +func assert_not_equal(a, b): + %{ assert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.' %} + if a == b: + # If a == b, add an unsatisfiable requirement. + [fp - 1] = [fp - 1] + 1 + end + + return () +end + +# Verifies that a >= 0 (or more precisely 0 <= a < RANGE_CHECK_BOUND). +func assert_nn(range_check_ptr, a) -> (range_check_ptr): + %{ assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.' %} + a = [range_check_ptr] + return (range_check_ptr + 1) +end + +# Verifies that a <= b (or more precisely 0 <= b - a < RANGE_CHECK_BOUND). +func assert_le(range_check_ptr, a, b) -> (range_check_ptr): + let (range_check_ptr) = assert_nn(range_check_ptr, b - a) + return (range_check_ptr) +end + +# Verifies that a <= b - 1 (or more precisely 0 <= b - 1 - a < RANGE_CHECK_BOUND). +func assert_lt(range_check_ptr, a, b) -> (range_check_ptr): + let (range_check_ptr) = assert_le(range_check_ptr, a, b - 1) + return (range_check_ptr) +end + +# Verifies that 0 <= a <= b. +# +# Prover assumption: a, b < RANGE_CHECK_BOUND. +func assert_nn_le(range_check_ptr, a, b) -> (range_check_ptr): + let (range_check_ptr) = assert_nn(range_check_ptr, a) + let (range_check_ptr) = assert_le(range_check_ptr, a, b) + return (range_check_ptr) +end + +# Asserts that value is in the range [lower, upper). +func assert_in_range(range_check_ptr, value, lower, upper) -> (range_check_ptr): + let (range_check_ptr) = assert_le(range_check_ptr, lower, value) + let (range_check_ptr) = assert_le(range_check_ptr, value, upper - 1) + return (range_check_ptr) +end + +# Asserts that a <= b. +# +# Assumptions: +# a and b are in the range [0, 2**250). +# PRIME - 2**250 > 2**(250 - 128) + 1 * RC_BOUND. +func assert_le_250_bit(range_check_ptr, a, b) -> (range_check_ptr): + let low = [range_check_ptr] + let high = [range_check_ptr + 1] + const UPPER_BOUND = %[2**(250)%] + const HIGH_PART_SHIFT = %[2**250 // 2**128 %] + %{ + # Soundness checks. + assert range_check_builtin.bound == 2**128 + assert ids.UPPER_BOUND == ids.HIGH_PART_SHIFT * range_check_builtin.bound + assert ids.a < ids.UPPER_BOUND, f'a={ids.a} is outside of the valid range.' + assert ids.b < ids.UPPER_BOUND, f'b={ids.b} is outside of the valid range.' + assert PRIME - ids.UPPER_BOUND > (ids.HIGH_PART_SHIFT + 1) * range_check_builtin.bound + + # Correctness check. + assert ids.a <= ids.b, f'a={ids.a} > b={ids.b}.' + %} + + tempvar diff = b - a + %{ + ids.high = ids.diff // ids.HIGH_PART_SHIFT + ids.low = ids.diff % ids.HIGH_PART_SHIFT + %} + + # Assuming the assert below, we have + # diff = high * HIGH_PART_SHIFT + low < (HIGH_PART_SHIFT + 1) * RC_BOUND < PRIME - UPPER_BOUND. + # If 0 <= b < a < UPPER_BOUND then diff < 0 => diff % P = PRIME - diff > PRIME - UPPER_BOUND. + # So given the soundness assumptions listed above it must be the case that a <= b. + assert diff = high * HIGH_PART_SHIFT + low + + return (range_check_ptr=range_check_ptr + 2) +end + +# Splits the unsigned integer lift of a field element into the higher 128 bit and lower 128 bit. +# The unsigned integer lift is the unique integer in the range [0, PRIME) that represents the field +# element. +# For example, if value=17 * 2^128 + 8, then high=17 and low=8. +func split_felt(range_check_ptr, value) -> (range_check_ptr, high, low): + const MAX_HIGH = %[(PRIME - 1) >> 128%] + const MAX_LOW = %[(PRIME - 1) & ((1 << 128) - 1)%] + + # Guess the low and high parts of the integer. + let low = [range_check_ptr] + let high = [range_check_ptr + 1] + let range_check_ptr = range_check_ptr + 2 + + %{ + assert PRIME < 2**256 + ids.low = ids.value & ((1 << 128) - 1) + ids.high = ids.value >> 128 + %} + assert value = high * %[2**128%] + low + if high == MAX_HIGH: + let (range_check_ptr) = assert_le(range_check_ptr, low, MAX_LOW) + else: + let (range_check_ptr) = assert_le(range_check_ptr, high, MAX_HIGH) + end + return (range_check_ptr=range_check_ptr, high=high, low=low) +end + +# Asserts that the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +# or equal to that of b. +# See split_felt() for more details. +func assert_le_felt(range_check_ptr, a, b) -> (range_check_ptr): + %{ + assert (ids.a % PRIME) <= (ids.b % PRIME), \ + f'a = {ids.a % PRIME} is not less than or equal to b = {ids.b % PRIME}.' + %} + alloc_locals + let (range_check_ptr, local a_high, local a_low) = split_felt(range_check_ptr, a) + let (range_check_ptr, b_high, b_low) = split_felt(range_check_ptr, b) + + if a_high == b_high: + let (range_check_ptr) = assert_le(range_check_ptr, a_low, b_low) + return (range_check_ptr) + end + let (range_check_ptr) = assert_le(range_check_ptr, a_high, b_high) + return (range_check_ptr) +end + +# Asserts that the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +# that of b. +func assert_lt_felt(range_check_ptr, a, b) -> (range_check_ptr): + %{ + assert (ids.a % PRIME) < (ids.b % PRIME), \ + f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.' + %} + alloc_locals + let (range_check_ptr, local a_high, local a_low) = split_felt(range_check_ptr, a) + let (range_check_ptr, b_high, b_low) = split_felt(range_check_ptr, b) + + if a_high == b_high: + let (range_check_ptr) = assert_lt(range_check_ptr, a_low, b_low) + return (range_check_ptr) + end + let (range_check_ptr) = assert_lt(range_check_ptr, a_high, b_high) + return (range_check_ptr) +end + +# Returns the absolute value of value. +# Prover asumption: -rc_bound < value < rc_bound. +func abs_value(range_check_ptr, value) -> (range_check_ptr, abs_value): + %{ + from starkware.cairo.common.math_utils import is_positive + memory[ap] = 1 if is_positive( + value=ids.value, prime=PRIME, rc_bound=range_check_builtin.bound) else 0 + %} + jmp is_positive if [ap] != 0; ap++ + [ap] = range_check_ptr + 1; ap++ # range_check_ptr + [ap] = value * (-1); ap++ # abs_value + [range_check_ptr] = [ap - 1] + return (...) + + is_positive: + [range_check_ptr] = value + return (range_check_ptr=range_check_ptr + 1, abs_value=value) +end + +# Returns the sign of value: -1, 0 or 1. +# Prover asumption: -rc_bound < value < rc_bound. +func sign(range_check_ptr, value) -> (range_check_ptr, sign): + if value == 0: + return (range_check_ptr=range_check_ptr, sign=0) + end + + %{ + from starkware.cairo.common.math_utils import is_positive + memory[ap] = 1 if is_positive( + value=ids.value, prime=PRIME, rc_bound=range_check_builtin.bound) else 0 + %} + jmp is_positive if [ap] != 0; ap++ + assert [range_check_ptr] = value * (-1) + return (range_check_ptr=range_check_ptr + 1, sign=-1) + + is_positive: + [range_check_ptr] = value + return (range_check_ptr=range_check_ptr + 1, sign=1) +end + +# Returns q and r such that: +# 0 <= q < rc_bound, 0 <= r < div and value = q * div + r. +# +# Assumption: 0 < div <= PRIME / rc_bound. +# Prover assumption: value / div < rc_bound. +# +# The value of div is restricted to make sure there is no overflow. +# q * div + r < (q + 1) * div <= rc_bound * (PRIME / rc_bound) = PRIME. +func unsigned_div_rem(range_check_ptr, value, div) -> (range_check_ptr, q, r): + let r = [range_check_ptr] + let q = [range_check_ptr + 1] + %{ + assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ + f'div={hex(ids.div)} is out of the valid range.' + ids.q, ids.r = divmod(ids.value, ids.div) + %} + let (range_check_ptr) = assert_le(range_check_ptr + 2, r, div - 1) + + assert value = q * div + r + return (range_check_ptr, q, r) +end + +# Returns q and r such that. -bound <= q < bound, 0 <= r < div -1 and value = q * div + r. +# value < PRIME / 2 is considered positive and value > PRIME / 2 is considered negative. +# +# Assumptions: +# 0 < div <= PRIME / (rc_bound) +# bound <= rc_bound / 2. +# Prover assumption: -bound <= value / div < bound. + +# The values of div and bound are restricted to make sure there is no overflow. +# q * div + r < (q + 1) * div <= rc_bound / 2 * (PRIME / rc_bound) +# q * div + r >= q * div >= -rc_bound / 2 * (PRIME / rc_bound) +func signed_div_rem(range_check_ptr, value, div, bound) -> (range_check_ptr, q, r): + let r = [range_check_ptr] + let biased_q = [range_check_ptr + 1] # == q + bound. + %{ + def as_int(val): + return val if val < PRIME // 2 else val - PRIME + + assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ + f'div={hex(ids.div)} is out of the valid range.' + + assert ids.bound <= range_check_builtin.bound // 2, \ + f'bound={hex(ids.bound)} is out of the valid range.' + + int_value = as_int(ids.value) + q, ids.r = divmod(int_value, ids.div) + + assert -ids.bound <= q < ids.bound, \ + f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).' + + ids.biased_q = q + ids.bound + %} + let q = biased_q - bound + assert value = q * div + r + let (range_check_ptr) = assert_le(range_check_ptr + 2, r, div - 1) + let (range_check_ptr) = assert_le(range_check_ptr, biased_q, 2 * bound - 1) + return (range_check_ptr, q, r) +end diff --git a/src/starkware/cairo/common/math_utils.py b/src/starkware/cairo/common/math_utils.py new file mode 100644 index 00000000..449c1f44 --- /dev/null +++ b/src/starkware/cairo/common/math_utils.py @@ -0,0 +1,17 @@ +def as_int(val, prime): + """ + Returns the lift of the given field element, val, as an integer in the range + (-prime/2, prime/2). + """ + return val if val < prime // 2 else val - prime + + +def is_positive(value, prime, rc_bound): + """ + Returns True if the lift of the given field element, as an integer in the range + (-rc_bound, rc_bound), is positive. + Raises an exception if the element is not within that range. + """ + val = as_int(value, prime) + assert abs(val) < rc_bound, f'value={val} is out of the valid range.' + return val > 0 diff --git a/src/starkware/cairo/common/memcpy.cairo b/src/starkware/cairo/common/memcpy.cairo new file mode 100644 index 00000000..5ae39f78 --- /dev/null +++ b/src/starkware/cairo/common/memcpy.cairo @@ -0,0 +1,38 @@ +# Copies len field elements from src to dst. +func memcpy(dst : felt*, src : felt*, len): + struct LoopFrame: + member dst : felt* = 0 + member src : felt* = 1 + const SIZE = 2 + end + + if len == 0: + return () + end + + let frame = cast(ap, LoopFrame*) + %{ vm_enter_scope({'n': ids.len}) %} + frame.dst = dst; ap++ + frame.src = src; ap++ + + loop: + let frame = cast(ap - LoopFrame.SIZE, LoopFrame*) + assert [frame.dst] = [frame.src] + + let continue_copying = [ap] + # Reserve space for continue_copying. + let next_frame = cast(ap + 1, LoopFrame*) + next_frame.dst = frame.dst + 1; ap++ + next_frame.src = frame.src + 1; ap++ + %{ + n -= 1 + ids.continue_copying = 1 if n > 0 else 0 + %} + static_assert next_frame + LoopFrame.SIZE == ap + 1 + jmp loop if continue_copying != 0; ap++ + # Assert that the loop executed len times. + len = cast(next_frame.src, felt) - cast(src, felt) + + %{ vm_exit_scope() %} + return () +end diff --git a/src/starkware/cairo/common/merkle_multi_update.cairo b/src/starkware/cairo/common/merkle_multi_update.cairo new file mode 100644 index 00000000..4154a741 --- /dev/null +++ b/src/starkware/cairo/common/merkle_multi_update.cairo @@ -0,0 +1,185 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.dict import DictAccess + +# Helper function for merkle_multi_update(). +func merkle_multi_update_inner( + hash_ptr : HashBuiltin*, update_ptr : DictAccess*, height, prev_root, new_root, index) -> ( + hash_ptr : HashBuiltin*, update_ptr : DictAccess*): + let hash0 : HashBuiltin* = hash_ptr + let hash1 : HashBuiltin* = hash_ptr + HashBuiltin.SIZE + %{ + if ids.height == 0: + assert node == ids.new_root, f'Expected node {ids.new_root}. Got {node}.' + case = 'leaf' + else: + prev_left, prev_right = preimage[ids.prev_root] + new_left, new_right = preimage[ids.new_root] + + left_child, right_child = node + if left_child is None: + assert right_child is not None, 'No updates in tree' + case = 'right' + elif right_child is None: + case = 'left' + else: + case = 'both' + + # Fill non deterministic hashes. + hash_ptr = ids.hash_ptr.address_ + memory[hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = prev_left + memory[hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = prev_right + memory[hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = new_left + memory[hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = new_right + + memory[ap] = int(case != 'right') + %} + jmp not_right if [ap] != 0; ap++ + + update_right: + prev_root = hash0.result + new_root = hash1.result + + # Make sure the same authentication path is used. + assert hash0.x = hash1.x + + # Call merkle_multi_update_inner recursively. + %{ vm_enter_scope(dict(node=right_child, preimage=preimage)) %} + merkle_multi_update_inner( + hash_ptr=hash_ptr + 2 * HashBuiltin.SIZE, + update_ptr=update_ptr, + height=height - 1, + prev_root=hash0.y, + new_root=hash1.y, + index=index * 2 + 1) + %{ vm_exit_scope() %} + return (...) + + not_right: + %{ memory[ap] = int(case != 'left') %} + jmp not_left if [ap] != 0; ap++ + + update_left: + prev_root = hash0.result + new_root = hash1.result + + # Make sure the same authentication path is used. + assert hash0.y = hash1.y + + # Call merkle_multi_update_inner recursively. + %{ vm_enter_scope(dict(node=left_child, preimage=preimage)) %} + merkle_multi_update_inner( + hash_ptr=hash_ptr + 2 * HashBuiltin.SIZE, + update_ptr=update_ptr, + height=height - 1, + prev_root=hash0.x, + new_root=hash1.x, + index=index * 2) + %{ vm_exit_scope() %} + return (...) + + not_left: + jmp update_both if height != 0 + + update_leaf: + # Note: height may underflow, but in order to reach 0 (which is verified here), we will need + # more steps than the field characteristic. The assumption is that it is not feasible. + + # Write the update. + let update : DictAccess* = update_ptr + %{ assert case == 'leaf' %} + index = update.key + prev_root = update.prev_value + new_root = update.new_value + + # Return values. + return (hash_ptr=hash_ptr, update_ptr=update + DictAccess.SIZE) + + update_both: + # Locals 0 and 1 are taken by non deterministic jumps. + let local_left_index = [fp + 2] + %{ assert case == 'both' %} + local_left_index = index * 2; ap++ + + prev_root = hash0.result + new_root = hash1.result + + # Update left. + %{ vm_enter_scope(dict(node=left_child, preimage=preimage)) %} + merkle_multi_update_inner( + hash_ptr=hash_ptr + 2 * HashBuiltin.SIZE, + update_ptr=update_ptr, + height=height - 1, + prev_root=hash0.x, + new_root=hash1.x, + index=index * 2) + %{ vm_exit_scope() %} + + # Update right. + # hash_ptr and update_ptr are already pushed. + # Push height to workaround one hint per line limitation. + [ap] = height - 1; ap++ # height. + %{ vm_enter_scope(dict(node=right_child, preimage=preimage)) %} + merkle_multi_update_inner(..., prev_root=hash0.y, new_root=hash1.y, index=local_left_index + 1) + %{ vm_exit_scope() %} + return (...) +end + +# Performs an efficient update of multiple leaves in a Merkle tree. +# +# Arguments: +# hash_ptr - hash builtin pointer. +# update_ptr - a list of DictAccess instances sorted by key (e.g., the result of squash_dict). +# height - height of merkle tree. +# prev_root - root value before the multi update. +# new_root - root value after the multi update. +# +# Hint arguments: +# preimage - a dictionary from the hash value of a merkle node to the pair of children values. +# +# Returns: +# hash_ptr - updated hash builtin pointer. +# +# Assumptions: The keys in the update_ptr list are unique and sorted. +# Guarantees: All the keys in the update_ptr list are < 2**height. +# +# Pseudocode: +# def diff(prev, new, height): +# if height == 0: return [(prev,new)] +# if prev.left==new.left: return diff(prev.right, new.right, height - 1) +# if prev.right==new.right: return diff(prev.left, new.left, height - 1) +# return diff(prev.left, new.left, height - 1) + \ +# diff(prev.right, new.right, height - 1) +func merkle_multi_update( + hash_ptr : HashBuiltin*, update_ptr : DictAccess*, n_updates, height, prev_root, + new_root) -> (hash_ptr : HashBuiltin*): + if n_updates == 0: + prev_root = new_root + return (hash_ptr=hash_ptr) + end + + %{ + from starkware.starkware_utils.merkle_tree.merkle_tree import build_update_tree + + # Build modifications list. + modifications = [] + for i in range(ids.n_updates): + curr_update_ptr = ids.update_ptr.address_ + i * ids.DictAccess.SIZE + modifications.append(( + memory[curr_update_ptr + ids.DictAccess.key], + memory[curr_update_ptr + ids.DictAccess.new_value])) + + node = build_update_tree(ids.height, modifications) + del modifications + vm_enter_scope(dict(node=node, preimage=preimage)) + %} + let ret_val = merkle_multi_update_inner( + hash_ptr=hash_ptr, + update_ptr=update_ptr, + height=height, + prev_root=prev_root, + new_root=new_root, + index=0) + assert ret_val.update_ptr = update_ptr + n_updates * DictAccess.SIZE + %{ vm_exit_scope() %} + return (hash_ptr=ret_val.hash_ptr) +end diff --git a/src/starkware/cairo/common/merkle_update.cairo b/src/starkware/cairo/common/merkle_update.cairo new file mode 100644 index 00000000..732a1e15 --- /dev/null +++ b/src/starkware/cairo/common/merkle_update.cairo @@ -0,0 +1,81 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin + +# Performs an update for a single leaf (index) in a Merkle tree (where 0 <= index < 2^height). +# Updates the leaf from prev_leaf to new_leaf, and returns the previous and new roots of the +# Merkle tree resulting from the change. +# In particular, given a secret authentication path (of the siblings of the nodes in the path from +# the root to the leaf), this function computes the roots twice - once with prev_leaf and once with +# new_leaf, where the verifier is guaranteed that the same authentication path is used. +func merkle_update(hash_ptr, height, prev_leaf, new_leaf, index) -> (prev_root, new_root, hash_ptr): + if height == 0: + # Assert that index is 0. + index = 0 + # Return the two leaves and the Pedersen pointer. + %{ + # Check that auth_path had the right number of elements. + assert len(auth_path) == 0, 'Got too many values in auth_path.' + %} + return (prev_root=prev_leaf, new_root=new_leaf, hash_ptr=hash_ptr) + end + + %{ memory[ap] = ids.index % 2 %} + jmp update_right if [ap] != 0; ap++ + + update_left: + %{ + # Hash hints. + sibling = auth_path.pop() + memory[ids.hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = sibling + memory[ids.hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = sibling + %} + prev_leaf = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.x] + new_leaf = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.x] + + # Make sure the same authentication path is used. + let right_sibling = ap + [right_sibling] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.y] + [right_sibling] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.y]; ap++ + + # Call merkle_update recursively. + [ap] = hash_ptr + 2 * HashBuiltin.SIZE; ap++ # hash_ptr. + [ap] = height - 1; ap++ # height. + [ap] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # prev_leaf. + [ap] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # new_leaf. + + let update_left_index = ap + %{ memory[ap] = ids.index // 2 %} + index = [update_left_index] * 2; ap++ # index. + merkle_update(...) # Tail recursion. + return (...) + + update_right: + %{ + # Hash hints. + sibling = auth_path.pop() + memory[ids.hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = sibling + memory[ids.hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = sibling + %} + prev_leaf = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.y] + new_leaf = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.y] + + # Make sure the same authentication path is used. + let left_sibling = ap + [left_sibling] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.x] + [left_sibling] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.x]; ap++ + + # Compute index - 1. + tempvar index_minus_one = index - 1 + + # Call merkle_update recursively. + [ap] = hash_ptr + 2 * HashBuiltin.SIZE; ap++ # hash_ptr. + [ap] = height - 1; ap++ # height. + [ap] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # prev_leaf. + [ap] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # new_leaf. + + let update_right_index = ap + %{ memory[ap] = ids.index // 2 %} + # Compute (index - 1) / 2. + index_minus_one = [update_right_index] * 2; ap++ # index. + merkle_update(...) # Tail recursion. + return (...) +end diff --git a/src/starkware/cairo/common/registers.cairo b/src/starkware/cairo/common/registers.cairo new file mode 100644 index 00000000..e9ba950c --- /dev/null +++ b/src/starkware/cairo/common/registers.cairo @@ -0,0 +1,49 @@ +# Returns the contents of the fp and pc registers of the calling function. +# The pc register's value is the address of the instruction that follows directly after the +# invocation of get_fp_and_pc(). +func get_fp_and_pc() -> (fp_val, pc_val): + # The call instruction itself already places the old fp and the return pc at [fp - 2], [fp - 1]. + # Thus, we can simply return, and the calling function may regard these as the return values + # of this function. + return (...) +end + +# Returns the content of the ap register just before this function was invoked. +func get_ap() -> (ap_val): + # Once get_ap() is invoked, fp points to ap + 2 (since the call instruction placed the old fp + # and pc in memory, advancing ap accordingly). + # Calling dummy_func places fp and pc at [fp], [fp + 1] (respectively), and advances ap by 2. + # Hence, going two cells above we get [fp] = ap + 2, and by subtracting 2 we get the desired ap + # value. + call dummy_func + return (ap_val=[ap - 2] - 2) +end + +func dummy_func(): + return () +end + +# Takes the value of a label (relative to program base) and returns the actual runtime address of +# that label in the memory. +# +# Example usage: +# +# func do_callback(...): +# ... +# end +# +# func do_thing_then_callback(callback): +# ... +# call abs callback +# end +# +# func main(): +# let (callback_address) = get_label_location(do_callback) +# do_thing_then_callback(callback=callback_address) +# end +func get_label_location(label_value) -> (res): + let (_, pc_val) = get_fp_and_pc() + + ret_pc_label: + return (res=label_value + pc_val - ret_pc_label) +end diff --git a/src/starkware/cairo/common/serialize.cairo b/src/starkware/cairo/common/serialize.cairo new file mode 100644 index 00000000..b7d92c5b --- /dev/null +++ b/src/starkware/cairo/common/serialize.cairo @@ -0,0 +1,48 @@ +# Appends a single word to the output pointer, and returns the pointer to the next output cell. +func serialize_word(output_ptr : felt*, word) -> (output_ptr : felt*): + assert [output_ptr] = word + return (output_ptr + 1) +end + +# Array right fold: computes the following: +# callback(callback(... callback(value, a[n-1]) ..., a[1]), a[0]) +# Arguments: +# value - the initial value. +# array - a pointer to an array. +# elm_size - the size of an element in the array. +# n_elms - the number of elements in the array. +# callback - a function pointer to the callback. Expected signature: (felt, T*) -> felt. +func array_rfold(value, array : felt*, n_elms, elm_size, callback) -> (res): + if n_elms == 0: + return (value) + end + + [ap] = value; ap++ + [ap] = array; ap++ + call abs callback + # We use ..., since we use the result of callback as the value for the next iteration. + array_rfold( + ..., array=array + elm_size, n_elms=n_elms - 1, elm_size=elm_size, callback=callback) + return (...) +end + +# Serializes an array of objects to output_ptr, and returns the pointer to the next output cell. +# The format is: len(array) || callback(a[0]) || ... || callback(a[n-1]) . +# Arguments: +# output_ptr - the pointer to serialize to. +# array - a pointer to an array. +# elm_size - the size of an element in the array. +# n_elms - the number of elements in the array. +# callback - a function pointer to the serialize function of a single element. +# Expected signature: (felt, T*) -> felt. +func serialize_array(output_ptr : felt*, array : felt*, n_elms, elm_size, callback) -> ( + output_ptr : felt*): + let (output_ptr) = serialize_word(output_ptr, n_elms) + let (output_ptr : felt*) = array_rfold( + value=cast(output_ptr, felt), + array=array, + n_elms=n_elms, + elm_size=elm_size, + callback=callback) + return (output_ptr) +end diff --git a/src/starkware/cairo/common/signature.cairo b/src/starkware/cairo/common/signature.cairo new file mode 100644 index 00000000..2f949ffe --- /dev/null +++ b/src/starkware/cairo/common/signature.cairo @@ -0,0 +1,15 @@ +from starkware.cairo.common.cairo_builtins import SignatureBuiltin + +# Verifies that the prover knows a signature of the given public_key on the given message. +# +# Prover assumption: (signature_r, signature_s) is a valid signature for the given public_key +# on the given message. +func verify_ecdsa_signature( + ecdsa_ptr : SignatureBuiltin*, message, public_key, signature_r, signature_s) -> ( + ecdsa_ptr : SignatureBuiltin*): + %{ ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s)) %} + assert ecdsa_ptr.message = message + assert ecdsa_ptr.pub_key = public_key + + return (ecdsa_ptr=ecdsa_ptr + SignatureBuiltin.SIZE) +end diff --git a/src/starkware/cairo/lang/CMakeLists.txt b/src/starkware/cairo/lang/CMakeLists.txt new file mode 100644 index 00000000..442b4de6 --- /dev/null +++ b/src/starkware/cairo/lang/CMakeLists.txt @@ -0,0 +1,25 @@ +add_subdirectory(builtins) +add_subdirectory(compiler) +add_subdirectory(scripts) +add_subdirectory(tracer) +add_subdirectory(vm) + +python_venv(cairo_lang_venv + PYTHON python3.7 + LIBS + cairo_common_lib + cairo_compile_lib + cairo_run_lib + cairo_script_lib + ${CAIRO_LANG_VENV_ADDITIONAL_LIBS} +) + +python_lib(cairo_instances_lib + PREFIX starkware/cairo/lang + + FILES + instances.py + + LIBS + cairo_run_builtins_lib +) diff --git a/src/starkware/cairo/lang/MANIFEST.in b/src/starkware/cairo/lang/MANIFEST.in new file mode 100644 index 00000000..f9bd1455 --- /dev/null +++ b/src/starkware/cairo/lang/MANIFEST.in @@ -0,0 +1 @@ +include requirements.txt diff --git a/src/starkware/cairo/lang/__init__.py b/src/starkware/cairo/lang/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/cairo/lang/builtins/CMakeLists.txt b/src/starkware/cairo/lang/builtins/CMakeLists.txt new file mode 100644 index 00000000..f59e9224 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/CMakeLists.txt @@ -0,0 +1,47 @@ +python_lib(cairo_run_builtins_lib + PREFIX starkware/cairo/lang/builtins + + FILES + checkpoints/instance_def.py + checkpoints/checkpoints_builtin_runner.py + hash/hash_builtin_runner.py + hash/instance_def.py + range_check/instance_def.py + range_check/range_check_builtin_runner.py + signature/instance_def.py + signature/signature_builtin_runner.py + + LIBS + cairo_relocatable + cairo_vm_lib + starkware_python_utils_lib +) + +python_lib(cairo_run_builtins_test_utils_lib + PREFIX starkware/cairo/lang/builtins + FILES + builtin_runner_test_utils.py + + LIBS + cairo_compile_lib + cairo_run_lib +) + +full_python_test(cairo_run_builtins_test + PREFIX starkware/cairo/lang/builtins + PYTHON python3.7 + TESTED_MODULES starkware/cairo/lang/builtins + + FILES + range_check/range_check_builtin_runner_test.py + signature/signature_builtin_runner_test.py + + LIBS + cairo_common_lib + cairo_compile_lib + cairo_run_lib + cairo_run_builtins_lib + cairo_run_builtins_test_utils_lib + starkware_python_test_utils_lib + pip_pytest +) diff --git a/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py b/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py new file mode 100644 index 00000000..e67c239e --- /dev/null +++ b/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py @@ -0,0 +1,17 @@ +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo +from starkware.cairo.lang.vm.cairo_runner import CairoRunner + +PRIME = 2**251 + 17 * 2**192 + 1 + + +def compile_and_run(code: str): + """ + Compiles the given code and runs it in the VM. + """ + program = compile_cairo(code, PRIME) + runner = CairoRunner(program, layout='small', proof_mode=False) + runner.initialize_segments() + end = runner.initialize_main_entrypoint() + runner.initialize_vm({}) + runner.run_until_pc(end) + runner.end_run() diff --git a/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py b/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py new file mode 100644 index 00000000..e4d37f6b --- /dev/null +++ b/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py @@ -0,0 +1,39 @@ +from typing import List + +from starkware.cairo.lang.vm.builtin_runner import SimpleBuiltinRunner +from starkware.python.math_utils import safe_div + +# Each sample consists of 2 cells (required pc and required fp). +CELLS_PER_SAMPLE = 2 + + +class CheckpointsBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, name: str, included: bool, sample_ratio: int): + self.sample_ratio = sample_ratio + self.samples: List = [] + super().__init__(name, included, sample_ratio, CELLS_PER_SAMPLE) + + def finalize_segments(self, runner): + memory = runner.vm.run_context.memory + memory[self.stop_ptr] = 0 + memory[self.stop_ptr + 1] = 0 + super().finalize_segments(runner) + + def get_used_cells_and_allocated_size(self, runner): + size = self.get_used_cells(runner) + return size, size + + def sample(self, step, pc, fp): + self.samples.append((step, pc, fp)) + + def relocate(self, relocate_value): + self.samples = [tuple(map(relocate_value, sample)) for sample in self.samples] + + def air_private_input(self, runner): + return {self.name: [ + { + 'index': safe_div(step, self.sample_ratio), + 'pc': hex(pc), + 'fp': hex(fp) + } + for step, pc, fp in self.samples]} diff --git a/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py b/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py new file mode 100644 index 00000000..a87b98a5 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py @@ -0,0 +1,8 @@ +import dataclasses + + +@dataclasses.dataclass +class CheckpointsInstanceDef: + # Defines the ratio between the number of steps to the number of samples. + # For every sample_ratio steps, we have one sample. + sample_ratio: int diff --git a/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py b/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py new file mode 100644 index 00000000..5495a255 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py @@ -0,0 +1,81 @@ +from typing import Any, Dict, Optional, Set + +from starkware.cairo.lang.vm.builtin_runner import BuiltinVerifier, SimpleBuiltinRunner +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +from starkware.python.math_utils import safe_div + +# Each hash consists of 3 cells (two inputs and one output). +CELLS_PER_HASH = 3 + + +class HashBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, name: str, included: bool, ratio: int, hash_func): + super().__init__(name, included, ratio, CELLS_PER_HASH) + self.hash_func = hash_func + self.stop_ptr: Optional[RelocatableValue] = None + self.verified_addresses: Set[MaybeRelocatable] = set() + + def add_auto_deduction_rules(self, runner): + def rule(vm, addr, verified_addresses): + memory = vm.run_context.memory + if addr.offset % CELLS_PER_HASH != 2: + return + if addr in verified_addresses: + return + if addr - 1 not in memory or addr - 2 not in memory: + return + assert vm.is_integer_value(memory[addr - 2]), \ + f'{self.name} builtin: Expected integer at address {addr - 2}. ' + \ + f'Got: {memory[addr - 2]}.' + assert vm.is_integer_value(memory[addr - 1]), \ + f'{self.name} builtin: Expected integer at address {addr - 1}. ' + \ + f'Got: {memory[addr - 1]}.' + res = self.hash_func(memory[addr - 2], memory[addr - 1]) + verified_addresses.add(addr) + return res + + runner.vm.add_auto_deduction_rule(self.base.segment_index, rule, self.verified_addresses) + + def air_private_input(self, runner) -> Dict[str, Any]: + assert self.base is not None, 'Uninitialized self.base.' + res: Dict[int, Any] = {} + for addr, val in runner.vm_memory.items(): + if not isinstance(addr, RelocatableValue) or \ + addr.segment_index != self.base.segment_index: + continue + idx = addr.offset // CELLS_PER_HASH + typ = addr.offset % CELLS_PER_HASH + if typ == 2: + continue + + assert isinstance(val, int) + res.setdefault(idx, {'index': idx})['x' if typ == 0 else 'y'] = hex(val) + + for index, item in res.items(): + assert 'x' in item, f'Missing first input of {self.name} instance {index}.' + assert 'y' in item, f'Missing second input of {self.name} instance {index}.' + + return {self.name: sorted(res.values(), key=lambda item: item['index'])} + + def get_additional_data(self): + return list(map(RelocatableValue.to_tuple, self.verified_addresses)) + + def extend_additional_data(self, data, relocate_callback): + for addr in data: + self.verified_addresses.add(relocate_callback(RelocatableValue.from_tuple(addr))) + + +class HashBuiltinVerifier(BuiltinVerifier): + def __init__(self, included: bool, ratio): + self.included = included + self.ratio = ratio + + def expected_stack(self, public_input): + if not self.included: + return [], [] + + addresses = public_input.memory_segments['pedersen'] + max_size = CELLS_PER_HASH * safe_div(public_input.n_steps, self.ratio) + assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ + addresses.begin_addr + max_size < 2**64 + return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/hash/instance_def.py b/src/starkware/cairo/lang/builtins/hash/instance_def.py new file mode 100644 index 00000000..2d0dfc08 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/hash/instance_def.py @@ -0,0 +1,20 @@ +import dataclasses +from typing import Optional + + +@dataclasses.dataclass +class PedersenInstanceDef: + # Defines the ratio between the number of steps to the number of pedersen instances. + # For every ratio steps, we have one instance. + ratio: int + + # Split to this many different components - for optimization. + repetitions: int + + # Size of hash. + element_height: int + element_bits: int + # Number of inputs for hash. + n_inputs: int + # The upper bound on the hash inputs. If None, the upper bound is 2^element_bits. + hash_limit: Optional[int] = None diff --git a/src/starkware/cairo/lang/builtins/range_check/instance_def.py b/src/starkware/cairo/lang/builtins/range_check/instance_def.py new file mode 100644 index 00000000..6a1319ba --- /dev/null +++ b/src/starkware/cairo/lang/builtins/range_check/instance_def.py @@ -0,0 +1,11 @@ +import dataclasses + + +@dataclasses.dataclass +class RangeCheckInstanceDef: + # Defines the ratio between the number of steps to the number of range check instances. + # For every ratio steps, we have one instance. + ratio: int + # Number of 16-bit range checks that will be used for each instance of the builtin. + # For example, n_parts=8 defines the range [0, 2^128). + n_parts: int diff --git a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py new file mode 100644 index 00000000..c4848182 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py @@ -0,0 +1,84 @@ +from typing import Any, Dict, Optional, Tuple + +from starkware.cairo.lang.vm.builtin_runner import BuiltinVerifier, SimpleBuiltinRunner +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.python.math_utils import safe_div + + +class RangeCheckBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, included: bool, ratio, inner_rc_bound, n_parts): + super().__init__('range_check', included, ratio) + self.inner_rc_bound = inner_rc_bound + self.bound = inner_rc_bound ** n_parts + self.n_parts = n_parts + + def add_validation_rules(self, runner): + def rule(memory, addr): + value = memory[addr] + # The range check builtin asserts that 0 <= value < BOUND. + # For example, if the layout uses 8 16-bit range-checks per instance, + # bound will be 2**(16 * 8) = 2**128. + assert 0 <= value < self.bound, \ + f'Value {value}, in range check builtin {addr - self.base}, is out of range ' \ + f'[0, {self.bound}).' + return {addr} + + runner.vm.add_validation_rule(self.base.segment_index, rule) + + def air_private_input(self, runner) -> Dict[str, Any]: + assert self.base is not None, 'Uninitialized self.base.' + res: Dict[int, Any] = {} + for addr, val in runner.vm_memory.items(): + if not isinstance(addr, RelocatableValue) or \ + addr.segment_index != self.base.segment_index: + continue + idx = addr.offset + + assert isinstance(val, int) + res[idx] = {'index': idx, 'value': hex(val)} + + return {'range_check': sorted(res.values(), key=lambda item: item['index'])} + + def get_range_check_usage(self, runner) -> Optional[Tuple[int, int]]: + assert self.base is not None, 'Uninitialized self.base.' + rc_min = None + rc_max = None + for addr, val in runner.vm_memory.items(): + if not isinstance(addr, RelocatableValue) or \ + addr.segment_index != self.base.segment_index: + continue + + # Split val into n_parts parts. + for _ in range(self.n_parts): + part_val = val % self.inner_rc_bound + + if rc_min is None: + rc_min = rc_max = part_val + else: + rc_min = min(rc_min, part_val) + rc_max = max(rc_max, part_val) + val //= self.inner_rc_bound + if rc_min is None or rc_max is None: + return None + return rc_min, rc_max + + def get_used_perm_range_check_units(self, runner) -> int: + used_cells, _ = self.get_used_cells_and_allocated_size(runner) + # Each cell in the range check segment requires n_parts range check units. + return used_cells * self.n_parts + + +class RangeCheckBuiltinVerifier(BuiltinVerifier): + def __init__(self, included: bool, ratio): + self.included = included + self.ratio = ratio + + def expected_stack(self, public_input): + if not self.included: + return [], [] + + addresses = public_input.memory_segments['range_check'] + max_size = safe_div(public_input.n_steps, self.ratio) + assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ + addresses.begin_addr + max_size < 2**64 + return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py new file mode 100644 index 00000000..46b76624 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py @@ -0,0 +1,25 @@ +import pytest + +from starkware.cairo.lang.builtins.builtin_runner_test_utils import PRIME, compile_and_run +from starkware.cairo.lang.vm.vm import VmException + + +def test_validation_rules(): + CODE_FORMAT = """ +%builtins range_check + +func main(range_check_ptr) -> (range_check_ptr): + assert [range_check_ptr] = {value} + return (range_check_ptr=range_check_ptr + 1) +end +""" + + # Test valid values. + compile_and_run(CODE_FORMAT.format(value=0)) + compile_and_run(CODE_FORMAT.format(value=1)) + + with pytest.raises( + VmException, + match=f'Value {PRIME - 1}, in range check builtin 0, is out of range ' + r'\[0, {bound}\)'.format(bound=2**128)): + compile_and_run(CODE_FORMAT.format(value=-1)) diff --git a/src/starkware/cairo/lang/builtins/signature/instance_def.py b/src/starkware/cairo/lang/builtins/signature/instance_def.py new file mode 100644 index 00000000..02e4666f --- /dev/null +++ b/src/starkware/cairo/lang/builtins/signature/instance_def.py @@ -0,0 +1,15 @@ +import dataclasses + + +@dataclasses.dataclass +class EcdsaInstanceDef: + # Defines the ratio between the number of steps to the number of ECDSA instances. + # For every ratio steps, we have one instance. + ratio: int + + # Split to this many different components - for optimization. + repetitions: int + + # Size of hash. + height: int + n_hash_bits: int diff --git a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py new file mode 100644 index 00000000..569b6e2a --- /dev/null +++ b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py @@ -0,0 +1,107 @@ +from typing import Any, Dict + +from starkware.cairo.lang.vm.builtin_runner import BuiltinVerifier, SimpleBuiltinRunner +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.python.math_utils import safe_div + +# Each signature consists of 2 cells (a public key and a message). +CELLS_PER_SIGNATURE = 2 + + +class SignatureBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, name: str, included: bool, ratio, process_signature, verify_signature): + """ + 'process_signature' is a function that takes signatures as saved in 'signatures' and + returns a dict representing the signature in the format expected by the component used by + the runner. + It may also assert that the signature is valid. + """ + super().__init__(name, included, ratio, CELLS_PER_SIGNATURE) + self.process_signature = process_signature + self.verify_signature = verify_signature + + # A dict of address -> signature. + self.signatures: Dict = {} + + def add_validation_rules(self, runner): + def rule(memory, addr): + # A signature builtin instance consists of a pair of public key and message. + if addr.offset % CELLS_PER_SIGNATURE == 0 and addr + 1 in memory: + pubkey_addr = addr + msg_addr = addr + 1 + elif addr.offset % CELLS_PER_SIGNATURE == 1 and addr - 1 in memory: + pubkey_addr = addr - 1 + msg_addr = addr + else: + return set() + + pubkey = memory[pubkey_addr] + msg = memory[msg_addr] + assert isinstance(pubkey, int), \ + f'ECDSA builtin: Expected public key at address {pubkey_addr} to be an integer. ' \ + f'Got: {pubkey}.' + assert isinstance(msg, int), \ + f'ECDSA builtin: Expected message hash at address {msg_addr} to be an integer. ' \ + f'Got: {msg}.' + assert pubkey_addr in self.signatures, \ + f'Signature hint is missing for ECDSA builtin at address {pubkey_addr}. ' \ + "Add it using 'ecdsa_builtin.add_signature'." + + signature = self.signatures[pubkey_addr] + assert self.verify_signature(pubkey, msg, signature), \ + f'Signature {signature}, is invalid, with respect to the public key {pubkey}, ' \ + f'and the message hash {msg}.' + return {pubkey_addr, msg_addr} + + runner.vm.add_validation_rule(self.base.segment_index, rule) + + def air_private_input(self, runner) -> Dict[str, Any]: + res: Dict[int, Any] = {} + for (addr, signature) in self.signatures.items(): + addr_offset = addr - self.base + idx = safe_div(addr_offset, CELLS_PER_SIGNATURE) + pubkey = runner.vm_memory[addr] + msg = runner.vm_memory[addr + 1] + res[idx] = { + 'index': idx, + 'pubkey': hex(pubkey), + 'msg': hex(msg), + 'signature_input': self.process_signature(pubkey, msg, signature), + } + + return {self.name: sorted(res.values(), key=lambda item: item['index'])} + + def add_signature(self, addr, signature): + """ + This function should be used in Cairo hints. + """ + assert isinstance(addr, RelocatableValue), \ + f'Expected memory address to be relocatable value. Found: {addr}.' + assert addr.offset % CELLS_PER_SIGNATURE == 0, \ + f'Signature hint must point to the public key cell, not {addr}.' + self.signatures[addr] = signature + + def get_additional_data(self): + return [ + (RelocatableValue.to_tuple(addr), signature) + for addr, signature in self.signatures.items()] + + def extend_additional_data(self, data, relocate_callback): + for addr, signature in data: + self.signatures[relocate_callback(RelocatableValue.from_tuple(addr))] = signature + + +class SignatureBuiltinVerifier(BuiltinVerifier): + def __init__(self, included: bool, ratio): + self.included = included + self.ratio = ratio + + def expected_stack(self, public_input): + if not self.included: + return [], [] + + addresses = public_input.memory_segments['signature'] + max_size = safe_div(public_input.n_steps, self.ratio) * CELLS_PER_SIGNATURE + assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ + addresses.begin_addr + max_size < 2**64 + return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py new file mode 100644 index 00000000..a7c656e9 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py @@ -0,0 +1,153 @@ +import dataclasses +from types import SimpleNamespace +from typing import Optional + +import pytest + +from starkware.cairo.lang.builtins.builtin_runner_test_utils import compile_and_run +from starkware.cairo.lang.vm.vm import VmException +from starkware.python.test_utils import maybe_raises + + +@dataclasses.dataclass +class SignatureCodeSections: + """ + Code sections relevant for using the signature builtin. + See code snippet structure below. + """ + hint: str + write_pubkey: str + write_msg: str + + +@dataclasses.dataclass +class SignatureExample: + code_sections: SignatureCodeSections + # Error message received by running the example code, in case there is any. + error_msg: Optional[str] + + +# Constants used for creating a code snippet using the signature builtin. +# See signature_builtin_runner_test.py. +SIG_PTR = 'ecdsa_ptr' +formats = SimpleNamespace( + hint_code_format='%{{ ecdsa_builtin.add_signature({addr}, {signature}) %}}', + pubkey_code_format=f'assert [{SIG_PTR} + SignatureBuiltin.pub_key] = {{pubkey}}', + msg_code_format=f'assert [{SIG_PTR} + SignatureBuiltin.message] = {{msg}}', +) + +# The address is used inside a hint. +VALID_ADDR = f'ids.{SIG_PTR}' +VALID_SIG = ( + 3086480810278599376317923499561306189851900463386393948998357832163236918254, + 598673427589502599949712887611119751108407514580626464031881322743364689811) +constants = SimpleNamespace( + valid_addr=VALID_ADDR, + invalid_addr=VALID_ADDR + ' + 1', + valid_sig=VALID_SIG, + invalid_sig=(VALID_SIG[0] + 1, VALID_SIG[1]), + valid_pubkey=1735102664668487605176656616876767369909409133946409161569774794110049207117, + valid_msg=2718, + invalid_pubkey_or_msg=SIG_PTR, +) + + +class SignatureTest: + """ + Aggregates test cases for the signature builtin runner. + A valid test case is added at initialization and further test cases are added based on the + valid case. + """ + + def __init__(self): + self.test_cases = {'valid': SignatureExample( + error_msg=None, + code_sections=SignatureCodeSections( + hint=formats.hint_code_format.format( + addr=constants.valid_addr, signature=constants.valid_sig), + write_pubkey=formats.pubkey_code_format.format( + pubkey=constants.valid_pubkey), + write_msg=formats.msg_code_format.format(msg=constants.valid_msg), + ) + )} + + def add_test_case(self, name: str, error_msg: Optional[str], **code_section_changes): + """ + Adds a new test case with the given error message, based on the valid case and the given + changes to it. + """ + self.test_cases[name] = SignatureExample( + code_sections=dataclasses.replace( + self.test_cases['valid'].code_sections, **code_section_changes), + error_msg=error_msg, + ) + + def get_test_cases(self): + return self.test_cases + + +# Signature code snippet structure. +CODE = """ +%builtins ecdsa +from starkware.cairo.common.cairo_builtins import SignatureBuiltin + +func main(ecdsa_ptr) -> (ecdsa_ptr): + {hint} + {write_pubkey} + {write_msg} + return(ecdsa_ptr=ecdsa_ptr + SignatureBuiltin.SIZE) +end +""" + +test = SignatureTest() +test.add_test_case( + name='invalid_signature_address', + error_msg='Signature hint must point to the public key cell, not 2:1.', + hint=formats.hint_code_format.format( + addr=constants.invalid_addr, signature=constants.valid_sig), +) + +test.add_test_case( + name='invalid_signature', + error_msg=( + r'Signature .* is invalid, with respect to the public key ' + '1735102664668487605176656616876767369909409133946409161569774794110049207117, ' + 'and the message hash 2718.'), + hint=formats.hint_code_format.format( + addr=constants.valid_addr, signature=constants.invalid_sig), +) + +test.add_test_case( + name='invalid_public_key', + error_msg='ECDSA builtin: Expected public key at address 2:0 to be an integer. Got: 2:0.', + write_pubkey=formats.pubkey_code_format.format(pubkey=constants.invalid_pubkey_or_msg), +) + +test.add_test_case( + name='invalid_message', + error_msg='ECDSA builtin: Expected message hash at address 2:1 to be an integer. Got: 2:0.', + write_msg=formats.msg_code_format.format(msg=constants.invalid_pubkey_or_msg), +) + +test.add_test_case( + name='missing_hint', + error_msg=( + 'Signature hint is missing for ECDSA builtin at address 2:0. ' + "Add it using 'ecdsa_builtin.add_signature'."), + hint='', +) + +# Missing public key or message would not cause a runtime error, but would fail the prover. +test.add_test_case(name='missing_public_key', error_msg=None, write_pubkey='') +test.add_test_case(name='missing_message', error_msg=None, write_msg='') + +test_cases = test.get_test_cases() + + +@pytest.mark.parametrize('case', test_cases.values(), ids=test_cases.keys()) +def test_validation_rules(case): + code = CODE.format(**dataclasses.asdict(case.code_sections)) + with maybe_raises( + expected_exception=VmException, error_message=case.error_msg, + escape_error_message=False): + compile_and_run(code) diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt new file mode 100644 index 00000000..b7ddb338 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -0,0 +1,132 @@ +python_lib(cairo_compile_lib + PREFIX starkware/cairo/lang/compiler + FILES + __init__.py + assembler.py + ast/__init__.py + ast/arguments.py + ast/bool_expr.py + ast/cairo_types.py + ast/code_elements.py + ast/expr.py + ast/formatting_utils.py + ast/instructions.py + ast/module.py + ast/node.py + ast/notes.py + ast/rvalue.py + ast/types.py + ast/visitor.py + cairo_compile.py + cairo_format.py + cairo.ebnf + constants.py + const_expr_checker.py + debug_info.py + encode.py + error_handling.py + expression_evaluator.py + expression_simplifier.py + expression_transformer.py + fields.py + identifier_definition.py + identifier_manager.py + identifier_manager_field.py + identifier_utils.py + import_loader.py + instruction_builder.py + instruction.py + location_utils.py + module_reader.py + parser_transformer.py + parser.py + preprocessor/compound_expressions.py + preprocessor/flow.py + preprocessor/identifier_collector.py + preprocessor/local_variables.py + preprocessor/preprocessor_error.py + preprocessor/preprocessor_utils.py + preprocessor/preprocessor.py + preprocessor/reg_tracking.py + program.py + references.py + scoped_name.py + substitute_identifiers.py + type_system_visitor.py + + LIBS + starkware_expression_string_lib + starkware_python_utils_lib + pip_marshmallow_dataclass + pip_marshmallow_enum + pip_marshmallow_oneofschema + pip_marshmallow + pip_lark_parser +) + +python_exe(cairo_compile_exe + VENV cairo_lang_venv + MODULE starkware.cairo.lang.compiler.cairo_compile +) + +python_venv(cairo_format_venv + PYTHON python3.7 + LIBS + cairo_compile_lib +) + +python_exe(cairo_format + VENV cairo_format_venv + MODULE starkware.cairo.lang.compiler.cairo_format +) + +python_lib(cairo_compile_test_utils_lib + PREFIX starkware/cairo/lang/compiler + FILES + preprocessor/preprocessor_test_utils.py + test_utils.py + + LIBS + cairo_compile_lib + pip_pytest +) + +full_python_test(cairo_compile_test + PREFIX starkware/cairo/lang/compiler + PYTHON python3.7 + TESTED_MODULES starkware/cairo/lang/compiler + + FILES + assembler_test.py + ast_objects_test.py + ast/formatting_utils_test.py + cairo_compile_test.py + encode_test.py + error_handling_test.py + expression_evaluator_test.py + expression_simplifier_test.py + identifier_definition_test.py + identifier_manager_field_test.py + identifier_manager_test.py + identifier_utils_test.py + import_loader_test.py + instruction_builder_test.py + instruction_test.py + module_reader_test.py + parser_errors_test.py + parser_test.py + preprocessor/compound_expressions_test.py + preprocessor/flow_test.py + preprocessor/identifier_collector_test.py + preprocessor/local_variables_test.py + preprocessor/preprocessor_test.py + preprocessor/reg_tracking_test.py + references_test.py + scoped_name_test.py + type_system_visitor_test.py + + LIBS + cairo_compile_lib + cairo_compile_test_utils_lib + pip_pytest +) diff --git a/src/starkware/cairo/lang/compiler/__init__.py b/src/starkware/cairo/lang/compiler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/cairo/lang/compiler/assembler.py b/src/starkware/cairo/lang/compiler/assembler.py new file mode 100644 index 00000000..65a05c43 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/assembler.py @@ -0,0 +1,49 @@ +from typing import Dict, List + +from starkware.cairo.lang.compiler.debug_info import DebugInfo, HintLocation, InstructionLocation +from starkware.cairo.lang.compiler.encode import encode_instruction +from starkware.cairo.lang.compiler.instruction_builder import build_instruction +from starkware.cairo.lang.compiler.preprocessor.preprocessor import PreprocessedProgram +from starkware.cairo.lang.compiler.program import CairoHint, Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +def assemble( + preprocessed_program: PreprocessedProgram, main_scope: ScopedName = ScopedName(), + add_debug_info: bool = False, file_contents_for_debug_info: Dict[str, str] = {}) -> Program: + data: List[int] = [] + hints: Dict[int, CairoHint] = {} + debug_info = DebugInfo(instruction_locations={}, file_contents=file_contents_for_debug_info) \ + if add_debug_info else None + + for inst in preprocessed_program.instructions: + if inst.hint: + hints[len(data)] = CairoHint( + code=inst.hint.hint_code, + accessible_scopes=inst.accessible_scopes, + flow_tracking_data=inst.flow_tracking_data) + if debug_info is not None and inst.instruction.location is not None: + hint_location = None + if inst.hint is not None and inst.hint.location is not None: + hint_location = HintLocation( + location=inst.hint.location, + n_prefix_newlines=inst.hint.n_prefix_newlines, + ) + debug_info.instruction_locations[len(data)] = \ + InstructionLocation( + inst=inst.instruction.location, + hint=hint_location, + accessible_scopes=inst.accessible_scopes, + flow_tracking_data=inst.flow_tracking_data) + data += [word for word in encode_instruction( + build_instruction(inst.instruction), prime=preprocessed_program.prime)] + + return Program( + prime=preprocessed_program.prime, + data=data, + hints=hints, + main_scope=main_scope, + identifiers=preprocessed_program.identifiers, + builtins=preprocessed_program.builtins, + reference_manager=preprocessed_program.reference_manager, + debug_info=debug_info) diff --git a/src/starkware/cairo/lang/compiler/assembler_test.py b/src/starkware/cairo/lang/compiler/assembler_test.py new file mode 100644 index 00000000..c8ed75b0 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/assembler_test.py @@ -0,0 +1,53 @@ +import pytest + +from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition, LabelDefinition +from starkware.cairo.lang.compiler.identifier_manager import ( + IdentifierManager, MissingIdentifierError) +from starkware.cairo.lang.compiler.preprocessor.flow import ReferenceManager +from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +def test_main_scope(): + identifiers = IdentifierManager.from_dict({ + ScopedName.from_string('a.b'): ConstDefinition(value=1), + ScopedName.from_string('x.y.z'): ConstDefinition(value=2), + }) + reference_manager = ReferenceManager() + + program = Program( + prime=0, data=[], hints={}, builtins=[], main_scope=ScopedName.from_string('a'), + identifiers=identifiers, reference_manager=reference_manager) + + # Check accessible identifiers. + assert program.get_identifier('b', ConstDefinition) + + # Ensure inaccessible identifiers. + with pytest.raises(MissingIdentifierError, match="Unknown identifier 'a'."): + program.get_identifier('a.b', ConstDefinition) + + with pytest.raises(MissingIdentifierError, match="Unknown identifier 'x'."): + program.get_identifier('x.y', ConstDefinition) + + with pytest.raises(MissingIdentifierError, match="Unknown identifier 'y'."): + program.get_identifier('y', ConstDefinition) + + +def test_program_start_property(): + identifiers = IdentifierManager.from_dict({ + ScopedName.from_string('some.main.__start__'): LabelDefinition(3), + }) + reference_manager = ReferenceManager() + main_scope = ScopedName.from_string('some.main') + + # The label __start__ is in identifiers. + program = Program( + prime=0, data=[], hints={}, builtins=[], main_scope=main_scope, identifiers=identifiers, + reference_manager=reference_manager) + assert program.start == 3 + + # The label __start__ is not in identifiers. + program = Program( + prime=0, data=[], hints={}, builtins=[], main_scope=main_scope, + identifiers=IdentifierManager(), reference_manager=reference_manager) + assert program.start == 0 diff --git a/src/starkware/cairo/lang/compiler/ast/__init__.py b/src/starkware/cairo/lang/compiler/ast/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/cairo/lang/compiler/ast/arguments.py b/src/starkware/cairo/lang/compiler/ast/arguments.py new file mode 100644 index 00000000..edf69891 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/arguments.py @@ -0,0 +1,23 @@ +import dataclasses +from typing import List, Optional, Sequence + +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import NoteListField, Notes +from starkware.cairo.lang.compiler.ast.types import TypedIdentifier +from starkware.cairo.lang.compiler.error_handling import Location + + +@dataclasses.dataclass +class IdentifierList(AstNode): + identifiers: List[TypedIdentifier] + notes: List[Notes] = NoteListField # type: ignore + location: Optional[Location] = LocationField + + def get_particles(self): + for note in self.notes: + note.assert_no_comments() + return [x.format() for x in self.identifiers] + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.identifiers diff --git a/src/starkware/cairo/lang/compiler/ast/bool_expr.py b/src/starkware/cairo/lang/compiler/ast/bool_expr.py new file mode 100644 index 00000000..ca6769be --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/bool_expr.py @@ -0,0 +1,22 @@ +import dataclasses +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.expr import Expression +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location + + +@dataclasses.dataclass +class BoolExpr(AstNode): + a: Expression + b: Expression + eq: bool + location: Optional[Location] = LocationField + + def get_particles(self): + relation = '==' if self.eq else '!=' + return [f'{self.a.format()} {relation} ', self.b.format()] + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] diff --git a/src/starkware/cairo/lang/compiler/ast/cairo_types.py b/src/starkware/cairo/lang/compiler/ast/cairo_types.py new file mode 100644 index 00000000..ce85a75a --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/cairo_types.py @@ -0,0 +1,69 @@ +import dataclasses +from abc import abstractmethod +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +class CairoType(AstNode): + location: Optional[Location] + + @abstractmethod + def format(self) -> str: + """ + Returns a representation of the type as a string. + """ + + def get_pointer_type(self) -> 'CairoType': + """ + Returns a type of a pointer to the current type. + """ + return TypePointer(pointee=self, location=self.location) + + +@dataclasses.dataclass +class TypeFelt(CairoType): + location: Optional[Location] = LocationField + + def format(self): + return 'felt' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class TypePointer(CairoType): + pointee: CairoType + location: Optional[Location] = LocationField + + def format(self): + return f'{self.pointee.format()}*' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.pointee] + + +@dataclasses.dataclass +class TypeStruct(CairoType): + scope: ScopedName + # Indicates whether scope refers to the fully resolved name. + is_fully_resolved: bool + location: Optional[Location] = LocationField + + def format(self): + return str(self.scope) + + @property + def resolved_scope(self): + """ + Verifies that is_fully_resolved=True and returns scope. + """ + assert self.is_fully_resolved, 'Type is expected to be fully resolved at this point.' + return self.scope + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] diff --git a/src/starkware/cairo/lang/compiler/ast/code_elements.py b/src/starkware/cairo/lang/compiler/ast/code_elements.py new file mode 100644 index 00000000..b3b8d4c8 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/code_elements.py @@ -0,0 +1,573 @@ +import dataclasses +from abc import abstractmethod +from typing import List, Optional, Sequence + +from starkware.cairo.lang.compiler.ast.arguments import IdentifierList +from starkware.cairo.lang.compiler.ast.bool_expr import BoolExpr +from starkware.cairo.lang.compiler.ast.expr import ArgListItem, Expression, ExprIdentifier +from starkware.cairo.lang.compiler.ast.formatting_utils import ( + INDENTATION, LocationField, ParticleFormattingConfig, create_particle_sublist, + particles_in_lines) +from starkware.cairo.lang.compiler.ast.instructions import InstructionAst +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.rvalue import Rvalue, RvalueCall, RvalueFuncCall +from starkware.cairo.lang.compiler.ast.types import TypedIdentifier +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.python.utils import indent + + +class CodeElement(AstNode): + @abstractmethod + def format(self, allowed_line_length): + """ + Formats the code element, without exceeding a line length of `allowed_line_length`. + """ + + +@dataclasses.dataclass +class CodeElementInstruction(CodeElement): + instruction: InstructionAst + + def get_particles(self): + return [self.instruction.format()] + + def format(self, allowed_line_length): + return self.instruction.format() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.instruction] + + +@dataclasses.dataclass +class CodeElementConst(CodeElement): + identifier: ExprIdentifier + expr: Expression + + def format(self, allowed_line_length): + return f'const {self.identifier.format()} = {self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementMember(CodeElement): + typed_identifier: TypedIdentifier + expr: Expression + + def format(self, allowed_line_length): + return f'member {self.typed_identifier.format()} = {self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementReference(CodeElement): + typed_identifier: TypedIdentifier + expr: Expression + + def format(self, allowed_line_length): + return f'let {self.typed_identifier.format()} = {self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementLocalVariable(CodeElement): + """ + Represents a statement of the form: + local x [: expr_type] = [expr] + + Both the expr_type and the initialization expr are optional. + """ + typed_identifier: TypedIdentifier + expr: Optional[Expression] + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + assignment = '' if self.expr is None else f' = {self.expr.format()}' + return f'local {self.typed_identifier.format()}{assignment}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementTemporaryVariable(CodeElement): + """ + Represents a statement of the form: + tempvar x = expr. + """ + typed_identifier: TypedIdentifier + expr: Expression + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f'tempvar {self.typed_identifier.format()} = {self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementCompoundAssertEq(CodeElement): + """ + Represents the statement "assert a = b" for two (compound) expressions a, b. + Unlike AssertEqInstruction, a CodeElementCompoundAssertEq may translate to a few instructions + to deal with expressions which contain more than one operation. + """ + a: Expression + b: Expression + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f'assert {self.a.format()} = {self.b.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class CodeElementStaticAssert(CodeElement): + a: Expression + b: Expression + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f'static_assert {self.a.format()} == {self.b.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class CodeElementReturn(CodeElement): + """ + Represents a statement of the form: + return ([ident=]expr, ...). + """ + exprs: List[ArgListItem] + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + expr_codes = [x.format() for x in self.exprs] + particles = ['return (', create_particle_sublist(expr_codes, ')')] + + return particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=True)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.exprs + + +@dataclasses.dataclass +class CodeElementFuncCall(CodeElement): + """ + Represents a statement of the form: + func_ident([ident=]expr, ...). + """ + func_call: RvalueFuncCall + + def get_particles(self): + return self.func_call.get_particles() + + def format(self, allowed_line_length): + return self.func_call.format(allowed_line_length) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.func_call] + + +@dataclasses.dataclass +class CodeElementReturnValueReference(CodeElement): + """ + Represents one of the references below. + let x [: type] = func(...) + let x [: type] = call func + let x [: type] = call rel 5 + where: + 'x [: type]' is the 'typed_identifier' + 'func(...)' is the 'func_call'. + """ + typed_identifier: TypedIdentifier + func_call: RvalueCall + + def format(self, allowed_line_length): + call_particles = self.func_call.get_particles() + first_particle = f'let {self.typed_identifier.format()} = ' + call_particles[0] + + return particles_in_lines( + particles=[first_particle] + call_particles[1:], + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=True)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.func_call] + + +@dataclasses.dataclass +class CodeElementUnpackBinding(CodeElement): + """ + Represents return value unpacking statement of the form: + let (a, b, c) = func(...) + where: + '(a, b, c)' is the 'unpacking_list' + 'func(...)' is the 'rvalue'. + """ + unpacking_list: IdentifierList + rvalue: Rvalue + + def format(self, allowed_line_length): + particles = self.rvalue.get_particles() + + end_particle = ') = ' + particles[0] + particles = ['let ('] + \ + create_particle_sublist(self.unpacking_list.get_particles(), end_particle) + \ + particles[1:] + + return particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=True)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.unpacking_list, self.rvalue] + + +@dataclasses.dataclass +class CodeElementLabel(CodeElement): + identifier: ExprIdentifier + + def format(self, allowed_line_length): + return f'{self.identifier.format()}:' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier] + + +@dataclasses.dataclass +class CodeElementHint(CodeElement): + hint_code: str + # The number of new lines following the "%{" symbol. + n_prefix_newlines: int + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + if self.hint_code == '': + return '%{\n%}' + if '\n' not in self.hint_code: + # One liner. + return f'%{{ {self.hint_code} %}}' + code = indent(self.hint_code, INDENTATION) + return f'%{{\n{code}\n%}}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class CodeElementEmptyLine(CodeElement): + def format(self, allowed_line_length): + return '' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class CommentedCodeElement(AstNode): + code_elm: CodeElement + comment: Optional[str] + + def format(self, allowed_line_length): + elm_str = self.code_elm.format(allowed_line_length=allowed_line_length) + comment_str = f'#{self.comment}' if self.comment is not None else '' + separator = ' ' if elm_str != '' and comment_str != '' else '' + return elm_str + separator + comment_str.rstrip() + + def fix_comment_spaces(self, allow_additional_comment_spaces: bool): + """ + Comments should start with exactly one space after '#' except for some cases (in which + allow_additional_comment_spaces=True). + Returns a copy of this instance with a fixed comment. + """ + comment = self.comment + + if comment is None: + return self + + if not allow_additional_comment_spaces: + comment = comment.strip() + if not comment.startswith(' '): + comment = ' ' + comment + + return CommentedCodeElement(code_elm=self.code_elm, comment=comment) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.code_elm] + + +@dataclasses.dataclass +class CodeBlock(AstNode): + code_elements: List[CommentedCodeElement] + + def format(self, allowed_line_length): + code_elements = remove_redundant_empty_lines(self.code_elements) + code_elements = add_empty_lines_before_labels(code_elements) + code_elements = fix_comment_spaces(code_elements) + + return ''.join(f'{code_elm.format(allowed_line_length)}\n' for code_elm in code_elements) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.code_elements + + +@dataclasses.dataclass +class CodeElementScoped(CodeElement): + """ + Represents a list of code elements that should be handled inside a scope. + This class does not appear naturally in the parsed AST. + """ + scope: ScopedName + code_elements: List[CodeElement] + + def format(self, allowed_line_length): + raise NotImplementedError(f'Formatting {type(self).__name__} is not supported.') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.code_elements + + +@dataclasses.dataclass +class CodeElementFunction(CodeElement): + """ + Represents either a 'func', 'namespace' or 'struct' statement. + For example: + func foo(x, y) -> (z, w): + return (z=x, w=y) + end + """ + # The type of the code element. Either 'func', 'namespace' or 'struct'. + element_type: str + identifier: ExprIdentifier + arguments: IdentifierList + returns: Optional[IdentifierList] + code_block: CodeBlock + + ARGUMENT_SCOPE = ScopedName.from_string('Args') + RETURN_SCOPE = ScopedName.from_string('Return') + + @property + def name(self): + return self.identifier.name + + def format(self, allowed_line_length): + code = self.code_block.format(allowed_line_length=allowed_line_length - INDENTATION) + code = indent(code, INDENTATION) + if self.element_type in ['struct', 'namespace']: + particles = [f'{self.element_type} {self.name}:'] + elif self.returns is not None: + particles = [ + f'{self.element_type} {self.name}(', + create_particle_sublist(self.arguments.get_particles(), ') -> ('), + create_particle_sublist(self.returns.get_particles(), '):')] + else: + particles = [ + f'{self.element_type} {self.name}(', + create_particle_sublist(self.arguments.get_particles(), '):')] + + header = particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION * 2)) + return f'{header}\n{code}end' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier, self.arguments, self.returns, self.code_block] + + +@dataclasses.dataclass +class CodeElementIf(CodeElement): + condition: BoolExpr + main_code_block: CodeBlock + else_code_block: Optional[CodeBlock] + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + cond_particles = ['if ', *self.condition.get_particles()] + cond_particles[-1] = cond_particles[-1] + ':' + code = particles_in_lines( + particles=cond_particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION)) + main_code = self.main_code_block.format( + allowed_line_length=allowed_line_length - INDENTATION) + main_code = indent(main_code, INDENTATION) + code += f'\n{main_code}' + if self.else_code_block is not None: + code += f'else:' + else_code = self.else_code_block.format( + allowed_line_length=allowed_line_length - INDENTATION) + else_code = indent(else_code, INDENTATION) + code += f'\n{else_code}' + code += 'end' + return code + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.condition, self.main_code_block, self.else_code_block] + + +class Directive(AstNode): + @abstractmethod + def format(self): + pass + + +@dataclasses.dataclass +class BuiltinsDirective(Directive): + builtins: List[str] + location: Optional[Location] = LocationField + + def format(self): + return f'%builtins {" ".join(self.builtins)}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class CodeElementDirective(CodeElement): + directive: Directive + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return self.directive.format() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.directive] + + +@dataclasses.dataclass +class CodeElementImport(CodeElement): + path: ExprIdentifier + orig_identifier: ExprIdentifier + local_name: Optional[ExprIdentifier] = None + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f'from {self.path.format()} import {self.orig_identifier.format()}' + \ + (f' as {self.local_name.format()}' if self.local_name else '') + + @property + def identifier(self): + return self.local_name if self.local_name is not None else self.orig_identifier + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.path, self.orig_identifier, self.local_name] + + +@dataclasses.dataclass +class CodeElementAllocLocals(CodeElement): + """ + Represents a statement of the form "alloc_locals". + """ + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return 'alloc_locals' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +def is_empty_line(code_element: CommentedCodeElement): + return isinstance(code_element.code_elm, CodeElementEmptyLine) and code_element.comment is None + + +def is_comment_line(code_element: CommentedCodeElement): + return isinstance(code_element.code_elm, CodeElementEmptyLine) and \ + code_element.comment is not None + + +def remove_redundant_empty_lines( + code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + """ + Returns a new list of code elements where redundant empty lines are removed. + Redundant empty lines are empty lines which are after: + 1. Empty lines. + 2. Labels. + """ + new_code_elements = [] + skip_empty_lines = True + for code_elm in code_elements: + if is_empty_line(code_elm): + # Empty line. + if skip_empty_lines: + continue + skip_empty_lines = True + elif isinstance(code_elm.code_elm, CodeElementLabel): + skip_empty_lines = True + else: + skip_empty_lines = False + new_code_elements.append(code_elm) + return new_code_elements + + +def add_empty_lines_before_labels( + code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + """ + Makes sure there is an empty line before labels. + The empty line is added before the comment lines preceding the label. + """ + new_code_elements_reversed = [] + add_empty_line = False + for code_elm in code_elements[::-1]: + if add_empty_line: + if is_empty_line(code_elm): + add_empty_line = False + elif not is_comment_line(code_elm): + new_code_elements_reversed.append(CommentedCodeElement( + code_elm=CodeElementEmptyLine(), + comment=None)) + add_empty_line = False + + if isinstance(code_elm.code_elm, CodeElementLabel): + add_empty_line = True + + new_code_elements_reversed.append(code_elm) + + return new_code_elements_reversed[::-1] + + +def fix_comment_spaces(code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + """ + Comments should start with exactly one space after '#'. When a comment is spread across several + lines, the next lines may start with more than one space. + Returns a copy of code_elements, where comment prefix spaces are fixed. + """ + new_code_elements = [] + allow_additional_comment_spaces = False + for code_elm in code_elements: + # Additional spaces are never allowed in inline comments. + if not is_comment_line(code_elm): + allow_additional_comment_spaces = False + + new_code_elements.append(code_elm.fix_comment_spaces(allow_additional_comment_spaces)) + + if is_comment_line(code_elm): + # Next comment line may have additional spaces. + allow_additional_comment_spaces = True + return new_code_elements diff --git a/src/starkware/cairo/lang/compiler/ast/expr.py b/src/starkware/cairo/lang/compiler/ast/expr.py new file mode 100644 index 00000000..fe231613 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/expr.py @@ -0,0 +1,303 @@ +import dataclasses +import re +from abc import abstractmethod +from typing import List, Optional, Sequence + +from starkware.cairo.lang.compiler.ast.cairo_types import CairoType +from starkware.cairo.lang.compiler.ast.formatting_utils import INDENTATION, LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import Notes, NotesField +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.instruction import Register +from starkware.python.expression_string import ExpressionString + + +class Expression(AstNode): + location: Optional[Location] + + def format(self): + res = str(self.to_expr_str()) + # Indent all lines except for the first. + res = res.replace('\n', '\n' + ' ' * INDENTATION) + # Remove trailing spaces. + res = re.sub(r' +\n', '\n', res) + return res + + @abstractmethod + def to_expr_str(self) -> ExpressionString: + """ + Formats the Expression and returns an ExpressionString. This is useful for automatic + insertion of parentheses (where required). + """ + + +@dataclasses.dataclass +class ExprConst(Expression): + val: int + location: Optional[Location] = LocationField + + def to_expr_str(self): + if self.val >= 0: + return ExpressionString.highest(str(self.val)) + return -ExpressionString.highest(str(-self.val)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprPyConst(Expression): + code: str + location: Optional[Location] = LocationField + + @classmethod + def from_str(cls, src: str, location: Optional[Location] = None): + assert src.startswith('%[') + assert src.endswith('%]') + code = src[2:-2] + return cls(code, location) + + def to_expr_str(self): + return ExpressionString.highest(f'%[{self.code}%]') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprIdentifier(Expression): + name: str + location: Optional[Location] = LocationField + + def to_expr_str(self): + return ExpressionString.highest(self.name) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +class ArgListItem(AstNode): + """ + Represents an item in function call or return statement. This can be either ExprAssignment or + EllipsisSymbol. + """ + + location: Optional[Location] + + @abstractmethod + def format(self): + pass + + +@dataclasses.dataclass +class EllipsisSymbol(ArgListItem): + location: Optional[Location] = LocationField + + def format(self): + return '...' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprAssignment(ArgListItem): + """ + A code element of the form [ident=]expr. The identifier is optional. + """ + identifier: Optional[ExprIdentifier] + expr: Expression + location: Optional[Location] = LocationField + + def format(self): + if self.identifier is None: + return self.expr.format() + return f'{self.identifier.format()}={self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier, self.expr] + + +@dataclasses.dataclass +class ArgList(AstNode): + """ + Represents a list of arguments (e.g., to a function call or a return statement). + For example: 'a=1, b=2'. + """ + args: List[ArgListItem] + notes: List[Notes] + has_trailing_comma: bool + location: Optional[Location] = LocationField + + def format(self): + if len(self.args) == 0: + assert len(self.notes) == 1 + return self.notes[0].format() + + code = '' + assert len(self.args) + 1 == len(self.notes) + for notes, arg in zip(self.notes[:-1], self.args): + if code != '': + code += ',' + if notes.empty: + code += ' ' + code += f'{notes.format()}{arg.format()}' + + # Add trailing comma at the end if necessary. + if self.has_trailing_comma: + code += ',' + code += self.notes[-1].format() + return code + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.args + + +@dataclasses.dataclass +class ExprReg(Expression): + reg: Register + location: Optional[Location] = LocationField + + def to_expr_str(self): + return ExpressionString.highest(self.reg.name.lower()) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprOperator(Expression): + a: Expression + op: str + b: Expression + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + self.notes.assert_no_comments() + a = self.a.to_expr_str() + b = self.b.to_expr_str() + if not self.notes.empty: + b = b.prepend('\n') + if self.op == '+': + return a + b + elif self.op == '-': + return a - b + elif self.op == '*': + return a * b + elif self.op == '/': + return a / b + else: + raise NotImplementedError(f"Unexpected operator '{self.op}'") + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class ExprAddressOf(Expression): + """ + Represents an expression of the form "&expr". + """ + expr: Expression + location: Optional[Location] = LocationField + + def to_expr_str(self): + return self.expr.to_expr_str().address_of() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr] + + +@dataclasses.dataclass +class ExprNeg(Expression): + val: Expression + location: Optional[Location] = LocationField + + def to_expr_str(self): + return -self.val.to_expr_str() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.val] + + +@dataclasses.dataclass +class ExprParentheses(Expression): + val: Expression + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + return ExpressionString.highest(f'({self.notes.format()}{str(self.val.to_expr_str())})') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.val] + + +@dataclasses.dataclass +class ExprDeref(Expression): + """ + Represents an expression of the form "[expr]". + """ + addr: Expression + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + self.notes.assert_no_comments() + notes = '' if self.notes.empty else '\n' + return ExpressionString.highest(f'[{notes}{str(self.addr.to_expr_str())}]') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.addr] + + +@dataclasses.dataclass +class ExprCast(Expression): + """ + Represents a cast expression of the form "cast(expr, T)" (which transforms expr to type T). + """ + expr: Expression + dest_type: CairoType + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + self.notes.assert_no_comments() + notes = '' if self.notes.empty else '\n' + return ExpressionString.highest( + f'cast({notes}{str(self.expr.to_expr_str())}, {self.dest_type.format()})') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr, self.dest_type] + + +@dataclasses.dataclass +class ExprTuple(Expression): + members: ArgList + location: Optional[Location] = LocationField + + def to_expr_str(self): + code = self.members.format() + return ExpressionString.highest(f'({code})') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.members] + + +@dataclasses.dataclass +class ExprFutureLabel(Expression): + """ + Represents a future label whose current pc is not known yet. + """ + identifier: ExprIdentifier + + def to_expr_str(self): + return self.identifier.to_expr_str() + + @property + def locaion(self): + return self.identifier.location + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier] diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py new file mode 100644 index 00000000..afc6c8de --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py @@ -0,0 +1,147 @@ +""" +Contains utils that help with formatting of Cairo code. +""" + +import dataclasses +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import field +from typing import List + +from starkware.cairo.lang.compiler.error_handling import LocationError + +INDENTATION = 4 +LocationField = field(default=None, hash=False, compare=False) +max_line_length_ctx_var: ContextVar[int] = ContextVar('max_line_length', default=100) + + +def get_max_line_length(): + return max_line_length_ctx_var.get() + + +@contextmanager +def set_max_line_length(line_length: bool): + """ + Context manager that sets max_line_length context variable. + """ + previous = get_max_line_length() + max_line_length_ctx_var.set(line_length) + yield + max_line_length_ctx_var.set(previous) + + +class FormattingError(LocationError): + pass + + +@dataclasses.dataclass +class ParticleFormattingConfig: + # The maximal line length. + allowed_line_length: int + # The indentation, starting from the second line. + line_indent: int + # The prefix of the first line. + first_line_prefix: str = '' + # At most one item per line. + one_per_line: bool = False + + +class ParticleLineBuilder: + """ + Builds particle lines, wrapping line lengths as needed. + """ + + def __init__(self, config: ParticleFormattingConfig): + self.lines: List[str] = [] + self.line = config.first_line_prefix + self.line_is_new = True + + self.config = config + + def newline(self): + """ + Opens a new line. + """ + if self.line_is_new: + return + self.lines.append(self.line) + self.line_is_new = True + self.line = ' ' * self.config.line_indent + + def add_to_line(self, string): + """ + Adds to current line, opening a new one if needed. + """ + if len(self.line) + len(string) > self.config.allowed_line_length and not self.line_is_new: + self.newline() + self.line += string + self.line_is_new = False + + def finalize(self): + """ + Finalizes the particle lines and returns the result. + """ + if self.line: + self.lines.append(self.line) + return '\n'.join(line.rstrip() for line in self.lines) + + +def create_particle_sublist(lst, end='', separator=', '): + if not lst: + # If the list is empty, return the single element 'end'. + return end + # Concatenate the 'separator' to all elements of the 'lst' and 'end' to the last one. + return [elm + separator for elm in lst[:-1]] + [lst[-1] + end] + + +def particles_in_lines(particles, config: ParticleFormattingConfig): + """ + Receives a list 'particles' that contains strings and particle sublists and generates lines + according to the following rules: + - The first line is not indented. All other lines start with 'line_indent' spaces. + - A line containing more than one particle can be no longer than 'allowed_line_length'. + - A sublist that cannot be fully concatenated to the current line opens a new line. + + Example: + particles_in_lines( + ['func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist(['a', 'b', 'c'], '):')], + 12, 4) + returns '''\ + func f( + x, y, + z) -> ( + a, b, + c):\ + ''' + With a longer line length we will get the lists on the same line: + particles_in_lines( + ['func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist([], '):')], + 19, 4) + returns '''\ + func f( + x, y, z) -> ():\ + ''' + """ + + builder = ParticleLineBuilder(config=config) + + for particle in particles: + if isinstance(particle, str): + builder.add_to_line(particle) + + if isinstance(particle, list): + # If the entire sublist fits in a single line, add it. + if sum(map(len, particle), config.line_indent) < config.allowed_line_length: + builder.add_to_line(''.join(particle)) + continue + builder.newline() + for member in particle: + if config.one_per_line: + builder.newline() + builder.add_to_line(member) + + return builder.finalize() diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py new file mode 100644 index 00000000..d5a9c0dd --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py @@ -0,0 +1,82 @@ +from starkware.cairo.lang.compiler.ast.formatting_utils import ( + ParticleFormattingConfig, create_particle_sublist, particles_in_lines) + + +def test_particles_in_lines(): + particles = [ + 'start ', + 'foo ', + 'bar ', + create_particle_sublist(['a', 'b', 'c', 'dddd', 'e', 'f'], '*'), + ' asdf', + ] + expected = """\ +start foo + bar + a, b, c, + dddd, e, + f* asdf\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=2), + ) == expected + + particles = [ + 'func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist(['a', 'b', 'c'], '):'), + ] + expected = """\ +func f( + x, y, + z) -> ( + a, b, + c):\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=4), + ) == expected + + # Same particles, using one_per_line=True. + expected = """\ +func f( + x, + y, + z) -> ( + a, + b, + c):\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=12, line_indent=4, one_per_line=True), + ) == expected + + # Same particles, using one_per_line=True, longer lines. + expected = """\ +func f( + x, y, z) -> ( + a, b, c):\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=19, line_indent=4, one_per_line=True), + ) == expected + + particles = [ + 'func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist([], '):'), + ] + expected = """\ +func f( + x, y, z) -> ():\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=19, line_indent=4), + ) == expected diff --git a/src/starkware/cairo/lang/compiler/ast/instructions.py b/src/starkware/cairo/lang/compiler/ast/instructions.py new file mode 100644 index 00000000..8cf7b297 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/instructions.py @@ -0,0 +1,171 @@ +import dataclasses +from abc import abstractmethod +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.expr import Expression, ExprIdentifier +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location + + +class InstructionBody(AstNode): + """ + Represents the instruction without the flag ap++. + """ + + location: Optional[Location] + + @abstractmethod + def format(self) -> str: + """ + Returns a string representing the instruction. + """ + + +@dataclasses.dataclass +class AssertEqInstruction(InstructionBody): + """ + Represents the instruction "a = b" for two expressions a, b. + """ + + a: Expression + b: Expression + location: Optional[Location] = LocationField + + def format(self): + return f'{self.a.format()} = {self.b.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class JumpInstruction(InstructionBody): + """ + Represents the instruction "jmp rel/abs". + """ + + val: Expression + relative: bool + location: Optional[Location] = LocationField + + def format(self): + return f'jmp {"rel" if self.relative else "abs"} {self.val.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.val] + + +@dataclasses.dataclass +class JumpToLabelInstruction(InstructionBody): + """ + Represents the instruction "jmp