diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b969ea..f8276a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,5 +13,6 @@ endif() find_program(PYTHON "python3") include("src/cmake_utils/cmake_rules.cmake") +include("src/starkware/cairo/lang/cairo_cmake_rules.cmake") add_subdirectory(src) diff --git a/LICENSE b/LICENSE index d19ef36..e414e19 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,14 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ +The files in this repository are governed by several licenses, +which are located under the licenses/ directory. - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +All code in this project except for the files under +'src/starkware/cairo/' and its subdirectories is +subject to the Apache License, +which can be found under 'licenses/ApacheLicense.txt'. - 1. Definitions. +The code under 'src/starkware/cairo/' and its subdirectories is +subject to the Cairo Toolchain License (Source Available), which +can be found under 'licenses/CairoToolchainLicense.txt'. - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2021 StarkWare Industries Ltd. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +For more information regarding licenses visit +https://starkware.co/licenses/. diff --git a/README.md b/README.md index ac8d811..3e930be 100644 --- a/README.md +++ b/README.md @@ -7,4 +7,4 @@ and updates a commitment to the state of the exchange onchain. StarkEx allows exchanges to provide non-custodial trading at scale with high liquidity and lower costs. -This repo holds a collection of tools to support StarkPerpetual - the Stark Exchange for derivatives trading. +This repo holds the Cairo program and a collection of tools to support StarkPerpetual - the Stark Exchange for derivatives trading. diff --git a/licenses/ApacheLicense.txt b/licenses/ApacheLicense.txt new file mode 100644 index 0000000..d19ef36 --- /dev/null +++ b/licenses/ApacheLicense.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 StarkWare Industries Ltd. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/licenses/CairoToolchainLicense.txt b/licenses/CairoToolchainLicense.txt new file mode 100644 index 0000000..275b21c --- /dev/null +++ b/licenses/CairoToolchainLicense.txt @@ -0,0 +1,35 @@ +Cairo Toolchain License (Source Available) + +Version 1.0 dated December 22, 2020 + +This license contains the terms and conditions under which StarkWare +Industries, Ltd ("StarkWare") makes available its Cairo Toolchain +("Toolchain"). Your use of the Toolchain is subject to these terms and +conditions. + +StarkWare grants you ("Licensee") a license to use the Toolchain, only +for the purpose of developing and compiling Cairo programs. Licensee's +use of the Toolchain is limited to non-commercial use, which means academic, +scientific, or research and development use, or evaluating the Cairo +language and Toolchain. + +StarkWare grants Licensee a license to modify the Toolchain, only as +necessary to fix errors. 
Licensee may, but is not obligated to, provide +any of Licensee's modifications to StarkWare. This license grants Licensee no +right to distribute the Toolchain or make copies of the Toolchain available +to others. + +These terms do not allow Licensee to sublicense or transfer any of Licensee's +rights to anyone else. These terms do not imply any other licenses not +expressly granted in this license. + +If Licensee violates any of these terms, or uses the Toolchain in a way not +authorized under this license, the license granted to Licensee ends immediately. +If Licensee makes, or authorizes any other person to make, any written claim +that the Toolchain infringes or contributes to infringement of any patent, all +rights granted to Licensee under this license end immediately. + +As far as the law allows, the Toolchain is provided AS IS, without any warranty +or condition, and StarkWare will not be liable to Licensee for any damages +arising out of these terms or the use or nature of the Toolchain, under any kind +of legal claim. diff --git a/scripts/requirements-deps.json b/scripts/requirements-deps.json index ea7eae0..9803e3b 100644 --- a/scripts/requirements-deps.json +++ b/scripts/requirements-deps.json @@ -7,10 +7,63 @@ "package_name": "PyYAML" } }, + { + "dependencies": [ + { + "installed_version": "3.0.1", + "key": "async-timeout", + "package_name": "async-timeout", + "required_version": ">=3.0,<4.0" + }, + { + "installed_version": "21.2.0", + "key": "attrs", + "package_name": "attrs", + "required_version": ">=17.3.0" + }, + { + "installed_version": "4.0.0", + "key": "chardet", + "package_name": "chardet", + "required_version": ">=2.0,<5.0" + }, + { + "installed_version": "5.1.0", + "key": "multidict", + "package_name": "multidict", + "required_version": ">=4.5,<7.0" + }, + { + "installed_version": "3.10.0.0", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": ">=3.6.5" + }, + { + "installed_version": "1.6.3", + "key": "yarl", + "package_name": "yarl", + "required_version": ">=1.0,<2.0" + } + ], + "package": { + "installed_version": "3.7.4.post0", + "key": "aiohttp", + "package_name": "aiohttp" + } + }, { "dependencies": [], "package": { - "installed_version": "20.3.0", + "installed_version": "3.0.1", + "key": "async-timeout", + "package_name": "async-timeout" + } + }, + { + "dependencies": [], + "package": { + "installed_version": "21.2.0", "key": "attrs", "package_name": "attrs" } @@ -34,7 +87,7 @@ { "dependencies": [], "package": { - "installed_version": "2020.12.5", + "installed_version": "2021.5.30", "key": "certifi", "package_name": "certifi" } @@ -65,14 +118,14 @@ { "dependencies": [ { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": ">=1.9.0" } ], "package": { - "installed_version": "0.16.1", + "installed_version": "0.17.0", "key": "ecdsa", "package_name": "ecdsa" } @@ -318,14 +371,14 @@ "required_version": ">=3.6.4" }, { - "installed_version": "3.4.1", + "installed_version": "3.5.0", "key": "zipp", "package_name": "zipp", "required_version": ">=0.5" } ], "package": { - "installed_version": "4.0.1", + "installed_version": "4.6.1", "key": "importlib-metadata", "package_name": "importlib-metadata" } @@ -354,7 +407,7 @@ } ], "package": { - "installed_version": "0.7.0a1", + "installed_version": "0.7.0", "key": "ipfshttpclient", "package_name": "ipfshttpclient" } @@ -362,31 +415,31 @@ { "dependencies": [ { - "installed_version": "20.3.0", + "installed_version": 
"21.2.0", "key": "attrs", "package_name": "attrs", "required_version": ">=17.4.0" }, { - "installed_version": "4.0.1", + "installed_version": "4.6.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": null }, { - "installed_version": "0.17.3", + "installed_version": "0.18.0", "key": "pyrsistent", "package_name": "pyrsistent", "required_version": ">=0.14.0" }, { - "installed_version": "56.0.0", + "installed_version": "57.0.0", "key": "setuptools", "package_name": "setuptools", "required_version": null }, { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": ">=1.11.0" @@ -398,6 +451,14 @@ "package_name": "jsonschema" } }, + { + "dependencies": [], + "package": { + "installed_version": "0.8.5", + "key": "lark-parser", + "package_name": "lark-parser" + } + }, { "dependencies": [], "package": { @@ -406,6 +467,65 @@ "package_name": "lru-dict" } }, + { + "dependencies": [], + "package": { + "installed_version": "3.12.2", + "key": "marshmallow", + "package_name": "marshmallow" + } + }, + { + "dependencies": [ + { + "installed_version": "3.12.2", + "key": "marshmallow", + "package_name": "marshmallow", + "required_version": ">=3.0.0,<4.0" + }, + { + "installed_version": "0.7.1", + "key": "typing-inspect", + "package_name": "typing-inspect", + "required_version": null + } + ], + "package": { + "installed_version": "8.4.1", + "key": "marshmallow-dataclass", + "package_name": "marshmallow-dataclass" + } + }, + { + "dependencies": [ + { + "installed_version": "3.12.2", + "key": "marshmallow", + "package_name": "marshmallow", + "required_version": ">=2.0.0" + } + ], + "package": { + "installed_version": "1.5.1", + "key": "marshmallow-enum", + "package_name": "marshmallow-enum" + } + }, + { + "dependencies": [ + { + "installed_version": "3.12.2", + "key": "marshmallow", + "package_name": "marshmallow", + "required_version": ">=3.0.0,<4.0.0" + } + ], + "package": { + "installed_version": "3.0.1", + "key": "marshmallow-oneofschema", + "package_name": "marshmallow-oneofschema" + } + }, { "dependencies": [], "package": { @@ -429,7 +549,7 @@ "required_version": null }, { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": null @@ -447,6 +567,14 @@ "package_name": "multiaddr" } }, + { + "dependencies": [], + "package": { + "installed_version": "5.1.0", + "key": "multidict", + "package_name": "multidict" + } + }, { "dependencies": [], "package": { @@ -473,7 +601,7 @@ } ], "package": { - "installed_version": "20.9", + "installed_version": "21.0", "key": "packaging", "package_name": "packaging" } @@ -481,7 +609,7 @@ { "dependencies": [ { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six", "required_version": ">=1.9.0" @@ -496,7 +624,7 @@ { "dependencies": [], "package": { - "installed_version": "21.1.1", + "installed_version": "21.1.2", "key": "pip", "package_name": "pip" } @@ -504,7 +632,7 @@ { "dependencies": [ { - "installed_version": "21.1.1", + "installed_version": "21.1.2", "key": "pip", "package_name": "pip", "required_version": ">=6.0.0" @@ -519,7 +647,7 @@ { "dependencies": [ { - "installed_version": "4.0.1", + "installed_version": "4.6.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": ">=0.12" @@ -534,14 +662,14 @@ { "dependencies": [ { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": 
"six", "required_version": ">=1.9" } ], "package": { - "installed_version": "3.15.8", + "installed_version": "3.17.3", "key": "protobuf", "package_name": "protobuf" } @@ -573,7 +701,7 @@ { "dependencies": [], "package": { - "installed_version": "0.17.3", + "installed_version": "0.18.0", "key": "pyrsistent", "package_name": "pyrsistent" } @@ -581,13 +709,13 @@ { "dependencies": [ { - "installed_version": "20.3.0", + "installed_version": "21.2.0", "key": "attrs", "package_name": "attrs", "required_version": ">=19.2.0" }, { - "installed_version": "4.0.1", + "installed_version": "4.6.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": ">=0.12" @@ -599,7 +727,7 @@ "required_version": null }, { - "installed_version": "20.9", + "installed_version": "21.0", "key": "packaging", "package_name": "packaging", "required_version": null @@ -624,7 +752,7 @@ } ], "package": { - "installed_version": "6.2.3", + "installed_version": "6.2.4", "key": "pytest", "package_name": "pytest" } @@ -632,7 +760,7 @@ { "dependencies": [ { - "installed_version": "2020.12.5", + "installed_version": "2021.5.30", "key": "certifi", "package_name": "certifi", "required_version": ">=2017.4.17" @@ -650,7 +778,7 @@ "required_version": ">=2.5,<3" }, { - "installed_version": "1.26.4", + "installed_version": "1.26.6", "key": "urllib3", "package_name": "urllib3", "required_version": ">=1.21.1,<1.27" @@ -680,7 +808,7 @@ { "dependencies": [], "package": { - "installed_version": "56.0.0", + "installed_version": "57.0.0", "key": "setuptools", "package_name": "setuptools" } @@ -688,7 +816,7 @@ { "dependencies": [], "package": { - "installed_version": "1.15.0", + "installed_version": "1.16.0", "key": "six", "package_name": "six" } @@ -732,10 +860,31 @@ "package_name": "typing-extensions" } }, + { + "dependencies": [ + { + "installed_version": "0.4.3", + "key": "mypy-extensions", + "package_name": "mypy-extensions", + "required_version": ">=0.3.0" + }, + { + "installed_version": "3.10.0.0", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": ">=3.7.4" + } + ], + "package": { + "installed_version": "0.7.1", + "key": "typing-inspect", + "package_name": "typing-inspect" + } + }, { "dependencies": [], "package": { - "installed_version": "1.26.4", + "installed_version": "1.26.6", "key": "urllib3", "package_name": "urllib3" } @@ -750,6 +899,12 @@ }, { "dependencies": [ + { + "installed_version": "3.7.4.post0", + "key": "aiohttp", + "package_name": "aiohttp", + "required_version": ">=3.7.4.post0,<4" + }, { "installed_version": "2.1.1", "key": "eth-abi", @@ -787,10 +942,10 @@ "required_version": ">=0.1.0,<1.0.0" }, { - "installed_version": "0.7.0a1", + "installed_version": "0.7.0", "key": "ipfshttpclient", "package_name": "ipfshttpclient", - "required_version": "==0.7.0a1" + "required_version": "==0.7.0" }, { "installed_version": "3.2.0", @@ -805,7 +960,7 @@ "required_version": ">=1.1.6,<2.0.0" }, { - "installed_version": "3.15.8", + "installed_version": "3.17.3", "key": "protobuf", "package_name": "protobuf", "required_version": ">=3.10.0,<4" @@ -823,14 +978,14 @@ "required_version": ">=3.7.4.1,<4" }, { - "installed_version": "8.1", + "installed_version": "9.1", "key": "websockets", "package_name": "websockets", - "required_version": ">=8.1.0,<9.0.0" + "required_version": ">=9.1,<10" } ], "package": { - "installed_version": "5.19.0", + "installed_version": "5.21.0", "key": "web3", "package_name": "web3" } @@ -838,7 +993,7 @@ { "dependencies": [], "package": { - 
"installed_version": "8.1", + "installed_version": "9.1", "key": "websockets", "package_name": "websockets" } @@ -851,10 +1006,37 @@ "package_name": "wheel" } }, + { + "dependencies": [ + { + "installed_version": "2.10", + "key": "idna", + "package_name": "idna", + "required_version": ">=2.0" + }, + { + "installed_version": "5.1.0", + "key": "multidict", + "package_name": "multidict", + "required_version": ">=4.0" + }, + { + "installed_version": "3.10.0.0", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": ">=3.7.4" + } + ], + "package": { + "installed_version": "1.6.3", + "key": "yarl", + "package_name": "yarl" + } + }, { "dependencies": [], "package": { - "installed_version": "3.4.1", + "installed_version": "3.5.0", "key": "zipp", "package_name": "zipp" } diff --git a/scripts/requirements-gen.txt b/scripts/requirements-gen.txt index 8a91ed4..657f9bf 100644 --- a/scripts/requirements-gen.txt +++ b/scripts/requirements-gen.txt @@ -1,10 +1,16 @@ ecdsa eth-hash[pycryptodome]==0.2.0 fastecdsa +lark-parser==0.8.5 +marshmallow-dataclass>=7.1.0 +marshmallow-enum +marshmallow-oneofschema +marshmallow>=3.2.1 mpmath mypy-extensions +numpy pipdeptree pytest sympy -web3 +Web3 PyYAML>=5.3.1,<5.4 diff --git a/scripts/requirements.txt b/scripts/requirements.txt index d243868..76203bf 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -1,12 +1,14 @@ # This file is autogenerated. Do not edit manually. -attrs==20.3.0 +aiohttp==3.7.4.post0 +async-timeout==3.0.1 +attrs==21.2.0 base58==2.1.0 bitarray==1.2.2 -certifi==2020.12.5 +certifi==2021.5.30 chardet==4.0.0 cytoolz==0.11.0 -ecdsa==0.16.1 +ecdsa==0.17.0 eth-abi==2.1.1 eth-account==0.5.4 eth-hash[pycryptodome]==0.2.0 @@ -18,35 +20,43 @@ eth-utils==1.9.5 fastecdsa==2.1.5 hexbytes==0.2.1 idna==2.10 -importlib-metadata==4.0.1 +importlib-metadata==4.6.1 iniconfig==1.1.1 -ipfshttpclient==0.7.0a1 +ipfshttpclient==0.7.0 jsonschema==3.2.0 +lark-parser==0.8.5 lru-dict==1.1.7 +marshmallow==3.12.2 +marshmallow-dataclass==8.4.1 +marshmallow-enum==1.5.1 +marshmallow-oneofschema==3.0.1 mpmath==1.2.1 multiaddr==0.0.9 +multidict==5.1.0 mypy-extensions==0.4.3 netaddr==0.8.0 -packaging==20.9 +packaging==21.0 parsimonious==0.8.1 pipdeptree==2.0.0 pluggy==0.13.1 -protobuf==3.15.8 +protobuf==3.17.3 py==1.10.0 pycryptodome==3.10.1 pyparsing==2.4.7 -pyrsistent==0.17.3 -pytest==6.2.3 +pyrsistent==0.18.0 +pytest==6.2.4 PyYAML==5.3.1 requests==2.25.1 rlp==2.0.1 -six==1.15.0 +six==1.16.0 sympy==1.8 toml==0.10.2 toolz==0.11.1 typing-extensions==3.10.0.0 -urllib3==1.26.4 +typing-inspect==0.7.1 +urllib3==1.26.6 varint==1.0.2 -web3==5.19.0 -websockets==8.1 -zipp==3.4.1 +web3==5.21.0 +websockets==9.1 +yarl==1.6.3 +zipp==3.5.0 diff --git a/src/services/exchange/cairo/definitions/constants.cairo b/src/services/exchange/cairo/definitions/constants.cairo new file mode 100644 index 0000000..d28f3f0 --- /dev/null +++ b/src/services/exchange/cairo/definitions/constants.cairo @@ -0,0 +1,4 @@ +const AMOUNT_UPPER_BOUND = %[2**64%] +const EXPIRATION_TIMESTAMP_UPPER_BOUND = %[2**32%] +const NONCE_UPPER_BOUND = %[2**32%] +const VAULT_ID_UPPER_BOUND = %[2**64%] diff --git a/src/services/exchange/cairo/order.cairo b/src/services/exchange/cairo/order.cairo new file mode 100644 index 0000000..f380a8b --- /dev/null +++ b/src/services/exchange/cairo/order.cairo @@ -0,0 +1,8 @@ +# Common struct for user signed orders (limit_order, withdrawal, transfer, etc.). 
+struct OrderBase: + member nonce : felt + member public_key : felt + member expiration_timestamp : felt + member signature_r : felt + member signature_s : felt +end diff --git a/src/services/exchange/cairo/signature_message_hashes.cairo b/src/services/exchange/cairo/signature_message_hashes.cairo new file mode 100644 index 0000000..c09cfae --- /dev/null +++ b/src/services/exchange/cairo/signature_message_hashes.cairo @@ -0,0 +1,135 @@ +from services.exchange.cairo.definitions.constants import ( + AMOUNT_UPPER_BOUND, EXPIRATION_TIMESTAMP_UPPER_BOUND, NONCE_UPPER_BOUND, VAULT_ID_UPPER_BOUND) +from services.exchange.cairo.order import OrderBase +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash import hash2 + +struct ExchangeLimitOrder: + member base : OrderBase* + member amount_buy : felt + member amount_sell : felt + member amount_fee : felt + member asset_id_buy : felt + member asset_id_sell : felt + member asset_id_fee : felt + member vault_buy : felt + member vault_sell : felt + member vault_fee : felt +end + +# limit_order_hash: +# Computes the hash of a limit order. +# +# The hash is defined as h(h(h(h(w1, w2), w3), w4), w5) where h is +# Starkware's Pedersen hash function and w1,...w5 are as follows: +# w1= token_sell +# w2= token_buy +# w3= token_fee +# w4= amount_sell (64 bit) || amount_buy (64 bit) || amount_fee (64 bit) || nonce (32 bit) +# w5= 0x3 (10 bit) || vault_fee (64 bit) || vault_sell (64 bit) || vault_buy (64 bit) +# || expiration_timestamp (32 bit) || 0 (17 bit) +# +# Assumptions: +# amount_sell, amount_buy, amount_fee < AMOUNT_UPPER_BOUND +# nonce < NONCE_UPPER_BOUND +# vault_sell, vault_buy, vault_fee < VAULT_ID_UPPER_BOUND +# expiration_timestamp < EXPIRATION_TIMESTAMP_UPPER_BOUND +func limit_order_hash{pedersen_ptr : HashBuiltin*}(limit_order : ExchangeLimitOrder*) -> ( + limit_order_hash): + let (msg) = hash2{hash_ptr=pedersen_ptr}( + x=limit_order.asset_id_sell, y=limit_order.asset_id_buy) + + let (msg) = hash2{hash_ptr=pedersen_ptr}(x=msg, y=limit_order.asset_id_fee) + + let packed_message0 = limit_order.amount_sell + let packed_message0 = packed_message0 * AMOUNT_UPPER_BOUND + limit_order.amount_buy + let packed_message0 = packed_message0 * AMOUNT_UPPER_BOUND + limit_order.amount_fee + + let packed_message0 = packed_message0 * NONCE_UPPER_BOUND + limit_order.base.nonce + let (msg) = hash2{hash_ptr=pedersen_ptr}(x=msg, y=packed_message0) + + const LIMIT_ORDER_WITH_FEES = 3 + let packed_message1 = LIMIT_ORDER_WITH_FEES + let packed_message1 = packed_message1 * VAULT_ID_UPPER_BOUND + limit_order.vault_fee + let packed_message1 = packed_message1 * VAULT_ID_UPPER_BOUND + limit_order.vault_sell + let packed_message1 = packed_message1 * VAULT_ID_UPPER_BOUND + limit_order.vault_buy + let expiration_timestamp = limit_order.base.expiration_timestamp + let packed_message1 = packed_message1 * EXPIRATION_TIMESTAMP_UPPER_BOUND + + expiration_timestamp + let packed_message1 = packed_message1 * %[2**17%] # Padding. + + let (msg) = hash2{hash_ptr=pedersen_ptr}(x=msg, y=packed_message1) + + return (limit_order_hash=msg) +end + +struct ExchangeTransfer: + member base : OrderBase* + # sender_public_key = base.public_key. 
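+    # (I.e. the sender's public key is not stored here; it is read from the common OrderBase.)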
+ member sender_vault_id : felt + member receiver_public_key : felt + member receiver_vault_id : felt + member amount : felt + member asset_id : felt + member src_fee_vault_id : felt + member asset_id_fee : felt + member max_amount_fee : felt +end + +# transfer_hash: +# Computes the hash of (possibly conditional) transfer request. +# +# The hash is defined as h(h(w1, w2), w3) for a normal transfer, where h is Starkware's Pedersen +# hash function and: +# w1 = h(h(asset_id, asset_id_fee), receiver_public_key) +# w2 = sender_vault_id (64 bit) || receiver_vault_id (64 bit) +# || src_fee_vault_id (64 bit) || nonce (32 bit) +# w3 = 0x4 (15 bit) || amount (64 bit) || max_amount_fee (64 bit) || expiration_timestamp (32 bit) +# || 0 (81 bit) +# where nonce and expiration_timestamp are under ExchangeTransfer.base. +# +# In case of a conditional transfer the hash is defined as h(h(h(w1, condition), w2), w3*) where +# w3* is the same as w3 except for the first element replaced with 0x5 (instead of 0x4). +# +# Assumptions: +# 0 <= nonce < NONCE_UPPER_BOUND +# 0 <= sender_vault_id, receiver_vault_id, src_fee_vault_id < VAULT_ID_UPPER_BOUND +# 0 <= amount, max_amount_fee < AMOUNT_UPPER_BOUND +# 0 <= expiration_timestamp < EXPIRATION_TIMESTAMP_UPPER_BOUND. +func transfer_hash{pedersen_ptr : HashBuiltin*}(transfer : ExchangeTransfer*, condition : felt) -> ( + transfer_hash): + alloc_locals + const TRANSFER_ORDER_TYPE = 4 + const CONDITIONAL_TRANSFER_ORDER_TYPE = 5 + let (msg) = hash2{hash_ptr=pedersen_ptr}(x=transfer.asset_id, y=transfer.asset_id_fee) + let (msg) = hash2{hash_ptr=pedersen_ptr}(x=msg, y=transfer.receiver_public_key) + + # Add condition to the signature hash if exists. + if condition != 0: + let (msg) = hash2{hash_ptr=pedersen_ptr}(x=msg, y=condition) + end + + # The sender is the one that pays the fee. + let src_fee_vault_id = transfer.sender_vault_id + let packed_message0 = transfer.sender_vault_id + let packed_message0 = packed_message0 * VAULT_ID_UPPER_BOUND + transfer.receiver_vault_id + let packed_message0 = packed_message0 * VAULT_ID_UPPER_BOUND + src_fee_vault_id + let packed_message0 = packed_message0 * NONCE_UPPER_BOUND + transfer.base.nonce + + let (msg) = hash2{hash_ptr=pedersen_ptr}(x=msg, y=packed_message0) + + if condition == 0: + # Normal Transfer. + tempvar packed_message1 = TRANSFER_ORDER_TYPE + else: + # Conditional transfer. + tempvar packed_message1 = CONDITIONAL_TRANSFER_ORDER_TYPE + end + let packed_message1 = packed_message1 * AMOUNT_UPPER_BOUND + transfer.amount + let packed_message1 = packed_message1 * AMOUNT_UPPER_BOUND + transfer.max_amount_fee + let packed_message1 = ( + packed_message1 * EXPIRATION_TIMESTAMP_UPPER_BOUND + transfer.base.expiration_timestamp) + let packed_message1 = packed_message1 * %[2**81%] # Padding. 
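+    # packed_message1 now encodes order_type || amount || max_amount_fee || expiration_timestamp
+    # || zero padding, i.e. w3 (or w3* for a conditional transfer) from the comment above.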
+ let (msg) = hash2{hash_ptr=pedersen_ptr}(x=msg, y=packed_message1) + return (transfer_hash=msg) +end diff --git a/src/services/perpetual/CMakeLists.txt b/src/services/perpetual/CMakeLists.txt index c031871..b3f13fe 100644 --- a/src/services/perpetual/CMakeLists.txt +++ b/src/services/perpetual/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(public) +add_subdirectory(cairo) diff --git a/src/services/perpetual/cairo/CMakeLists.txt b/src/services/perpetual/cairo/CMakeLists.txt new file mode 100644 index 0000000..14013d4 --- /dev/null +++ b/src/services/perpetual/cairo/CMakeLists.txt @@ -0,0 +1,42 @@ +cairo_compile(perpetual_cairo_program + perpetual_cairo_compiled.json main.cairo "--debug_info_with_source" +) + +python_lib(perpetual_cairo_program_lib + PREFIX services/perpetual/cairo + + ARTIFACTS + "${CMAKE_CURRENT_BINARY_DIR}/perpetual_cairo_compiled.json perpetual_cairo_compiled.json" + + LIBS + cairo_common_lib +) + +add_dependencies(perpetual_cairo_program_lib perpetual_cairo_program) + +python_lib(perpetual_cairo_program_hash_test_lib + PREFIX services/perpetual/cairo + FILES + program_hash_test.py + LIBS + cairo_hash_program_lib + perpetual_cairo_program_lib + pip_pytest +) + +python_venv(perpetual_cairo_program_hash_test_venv + PYTHON python3.7 + LIBS + perpetual_cairo_program_hash_test_lib +) + +python_test(perpetual_cairo_program_hash_test + VENV perpetual_cairo_program_hash_test_venv + TESTED_MODULES services/perpetual/cairo +) + +python_exe(generate_perpetual_cairo_program_hash + VENV perpetual_cairo_program_hash_test_venv + MODULE services.perpetual.cairo.program_hash_test + ARGS "--fix" +) diff --git a/src/services/perpetual/cairo/definitions/constants.cairo b/src/services/perpetual/cairo/definitions/constants.cairo new file mode 100644 index 0000000..3ef3a57 --- /dev/null +++ b/src/services/perpetual/cairo/definitions/constants.cairo @@ -0,0 +1,53 @@ +from services.exchange.cairo.definitions.constants import ( + AMOUNT_UPPER_BOUND, EXPIRATION_TIMESTAMP_UPPER_BOUND, NONCE_UPPER_BOUND) + +# This is the lower bound for actual synthetic asset and limit order collateral amounts. Those +# amounts can't be 0 to prevent order replay and arbitrary actual fees. +const POSITIVE_AMOUNT_LOWER_BOUND = 1 +# ASSET_ID_UPPER_BOUND is set so that PositionAsset could be packed into a field element. +const ASSET_ID_UPPER_BOUND = %[2**120%] + +# A valid balance satisfies BALANCE_LOWER_BOUND < balance < BALANCE_UPPER_BOUND. +const BALANCE_UPPER_BOUND = %[2**63%] +const BALANCE_LOWER_BOUND = -BALANCE_UPPER_BOUND + +const TOTAL_VALUE_UPPER_BOUND = %[2**63%] +const TOTAL_VALUE_LOWER_BOUND = -%[2**63%] + +const TOTAL_RISK_UPPER_BOUND = %[2**64%] + +const N_ASSETS_UPPER_BOUND = %[2**16%] +const POSITION_MAX_SUPPORTED_N_ASSETS = %[2**6%] + +# Fixed point (.32) representation of the number 1. +const FXP_32_ONE = %[2**32%] +# Oracle prices are signed by external entities, which use a fixed point representation where +# 10**18 is 1.0 . 
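+# For example, the value 1.5 is written as 15 * 10**17 in that external representation, and as
+# 1.5 * FXP_32_ONE = 3 * 2**31 in the internal 32.32 fixed point representation.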
+const EXTERNAL_PRICE_FIXED_POINT_UNIT = %[10**18%]
+
+const ORACLE_PRICE_QUORUM_LOWER_BOUND = %[1%]
+const ORACLE_PRICE_QUORUM_UPPER_BOUND = %[2**32%]
+
+const POSITION_ID_UPPER_BOUND = %[2**64%]
+const ORDER_ID_UPPER_BOUND = %[2**64%]
+# Fixed point (32.32)
+const FUNDING_INDEX_UPPER_BOUND = %[2**63%]
+const FUNDING_INDEX_LOWER_BOUND = -%[2**63%]
+
+# Fixed point (0.32)
+const RISK_FACTOR_LOWER_BOUND = %[1%]
+const RISK_FACTOR_UPPER_BOUND = FXP_32_ONE
+
+# Fixed point (32.32)
+const PRICE_LOWER_BOUND = 1
+const PRICE_UPPER_BOUND = %[2**64%]
+
+const EXTERNAL_PRICE_UPPER_BOUND = %[2**120%]
+
+const ASSET_RESOLUTION_LOWER_BOUND = %[1%]
+const ASSET_RESOLUTION_UPPER_BOUND = %[2**64%]
+const COLLATERAL_ASSET_ID_UPPER_BOUND = %[2**250%]
+
+# General Cairo constants.
+const SIGNED_MESSAGE_BOUND = %[2**251%]
+const RANGE_CHECK_BOUND = %[2**128%]
diff --git a/src/services/perpetual/cairo/definitions/general_config.cairo b/src/services/perpetual/cairo/definitions/general_config.cairo
new file mode 100644
index 0000000..c84984b
--- /dev/null
+++ b/src/services/perpetual/cairo/definitions/general_config.cairo
@@ -0,0 +1,55 @@
+# Information about the unique collateral asset of the system.
+struct CollateralAssetInfo:
+    member asset_id : felt
+    # Resolution: Each unit of balance in the oracle is worth this many units in our system.
+    member resolution : felt
+end
+
+# Information about the unique fee position of the system. All fees are paid to it.
+struct FeePositionInfo:
+    member position_id : felt
+    member public_key : felt
+end
+
+# Information about a synthetic asset in the system.
+struct SyntheticAssetInfo:
+    member asset_id : felt  # Asset id.
+    # Resolution: Each unit of balance in the oracle is worth this many units in our system.
+    member resolution : felt
+    # 32.32 fixed point number indicating the risk factor of the asset. This is used in deciding if
+    # a position is well leveraged.
+    member risk_factor : felt
+    # A list of IDs associated with the asset, on which the oracle price providers sign.
+    member n_oracle_price_signed_asset_ids : felt
+    member oracle_price_signed_asset_ids : felt*
+    # The minimum number of signatures required to sign a price.
+    member oracle_price_quorum : felt
+    # A list of oracle signer public keys.
+    member n_oracle_price_signers : felt
+    member oracle_price_signers : felt*
+end
+
+# Configuration for timestamp validation.
+struct TimestampValidationConfig:
+    member price_validity_period : felt
+    member funding_validity_period : felt
+end
+
+struct GeneralConfig:
+    # 32.32 fixed point number, indicating the maximum rate of change of a normalized funding index.
+    # Units are (1) / (time * price).
+    member max_funding_rate : felt
+    # See CollateralAssetInfo.
+    member collateral_asset_info : CollateralAssetInfo*
+    # See FeePositionInfo.
+    member fee_position_info : FeePositionInfo*
+    # Information about the synthetic assets in the system. See SyntheticAssetInfo.
+    member n_synthetic_assets_info : felt
+    member synthetic_assets_info : SyntheticAssetInfo*
+    # Height of the Merkle tree in which positions are kept.
+    member positions_tree_height : felt
+    # Height of the Merkle tree in which orders are kept.
+    member orders_tree_height : felt
+    # See TimestampValidationConfig.
+    member timestamp_validation_config : TimestampValidationConfig*
+end
diff --git a/src/services/perpetual/cairo/definitions/general_config_hash.cairo b/src/services/perpetual/cairo/definitions/general_config_hash.cairo
new file mode 100644
index 0000000..a333bf2
--- /dev/null
+++ b/src/services/perpetual/cairo/definitions/general_config_hash.cairo
@@ -0,0 +1,114 @@
+from services.perpetual.cairo.definitions.general_config import (
+    CollateralAssetInfo, FeePositionInfo, GeneralConfig, SyntheticAssetInfo,
+    TimestampValidationConfig)
+from starkware.cairo.common.cairo_builtins import HashBuiltin
+from starkware.cairo.common.hash_state import (
+    HashState, hash_finalize, hash_init, hash_update, hash_update_single)
+
+# A synthetic asset entry containing its asset id and its config's hash.
+struct AssetConfigHashEntry:
+    member asset_id : felt
+    member config_hash : felt
+end
+
+# Calculates the hash of a SyntheticAssetInfo.
+func synthetic_asset_info_hash{pedersen_ptr : HashBuiltin*}(
+        synthetic_asset_info_ptr : SyntheticAssetInfo*) -> (hash):
+    let hash_ptr = pedersen_ptr
+    with hash_ptr:
+        let (hash_state_ptr) = hash_init()
+        let (hash_state_ptr) = hash_update_single(hash_state_ptr, synthetic_asset_info_ptr.asset_id)
+        let (hash_state_ptr) = hash_update_single(
+            hash_state_ptr, synthetic_asset_info_ptr.resolution)
+        let (hash_state_ptr) = hash_update_single(
+            hash_state_ptr, synthetic_asset_info_ptr.risk_factor)
+        let (hash_state_ptr) = hash_update(
+            hash_state_ptr,
+            synthetic_asset_info_ptr.oracle_price_signed_asset_ids,
+            synthetic_asset_info_ptr.n_oracle_price_signed_asset_ids)
+        let (hash_state_ptr) = hash_update_single(
+            hash_state_ptr, synthetic_asset_info_ptr.oracle_price_quorum)
+        let (hash_state_ptr) = hash_update(
+            hash_state_ptr,
+            synthetic_asset_info_ptr.oracle_price_signers,
+            synthetic_asset_info_ptr.n_oracle_price_signers)
+
+        static_assert SyntheticAssetInfo.SIZE == 8
+        let (hash) = hash_finalize(hash_state_ptr)
+    end
+    let pedersen_ptr = hash_ptr
+    return (hash=hash)
+end
+
+# Calculates the hash of a GeneralConfig. The returned value is the hash of all fields except the
+# synthetic assets info. To get the hashes of the synthetic assets, use
+# general_config_hash_synthetic_assets.
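+# (For orientation: the nine scalar fields below are chained into a single Pedersen hash-state in
+# this order: max_funding_rate, the collateral asset's id and resolution, the fee position's id
+# and public key, the two Merkle tree heights, and the two timestamp validity periods.)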
+func general_config_hash{pedersen_ptr : HashBuiltin*}(general_config_ptr : GeneralConfig*) -> ( + hash): + let hash_ptr = pedersen_ptr + with hash_ptr: + let (hash_state_ptr) = hash_init() + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.max_funding_rate) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.collateral_asset_info.asset_id) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.collateral_asset_info.resolution) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.fee_position_info.position_id) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.fee_position_info.public_key) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.positions_tree_height) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.orders_tree_height) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.timestamp_validation_config.price_validity_period) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr, general_config_ptr.timestamp_validation_config.funding_validity_period) + + static_assert GeneralConfig.SIZE == 8 + let (hash) = hash_finalize(hash_state_ptr) + end + let pedersen_ptr = hash_ptr + return (hash=hash) +end + +func synthetic_assets_info_to_asset_configs{pedersen_ptr : HashBuiltin*}( + output_ptr : AssetConfigHashEntry*, n_synthetic_assets_info, + synthetic_assets_info : SyntheticAssetInfo*) -> (): + if n_synthetic_assets_info == 0: + return () + end + + let (hash) = synthetic_asset_info_hash(synthetic_assets_info) + assert output_ptr.asset_id = synthetic_assets_info.asset_id + assert output_ptr.config_hash = hash + return synthetic_assets_info_to_asset_configs( + output_ptr=output_ptr + AssetConfigHashEntry.SIZE, + n_synthetic_assets_info=n_synthetic_assets_info - 1, + synthetic_assets_info=synthetic_assets_info + SyntheticAssetInfo.SIZE) +end + +# Calculates the hash of the synthetic assets of a GeneralConfig. Returns a list of each synthetic +# asset info's hash. +func general_config_hash_synthetic_assets{pedersen_ptr : HashBuiltin*}( + general_config_ptr : GeneralConfig*) -> ( + n_asset_configs, asset_configs : AssetConfigHashEntry*): + local asset_configs : AssetConfigHashEntry* + alloc_locals + + %{ + ids.asset_configs = asset_configs = segments.add() + segments.finalize( + asset_configs.segment_index, + ids.general_config_ptr.n_synthetic_assets_info * ids.AssetConfigHashEntry.SIZE + ) + %} + + synthetic_assets_info_to_asset_configs( + output_ptr=asset_configs, + n_synthetic_assets_info=general_config_ptr.n_synthetic_assets_info, + synthetic_assets_info=general_config_ptr.synthetic_assets_info) + return (n_asset_configs=general_config_ptr.n_synthetic_assets_info, asset_configs=asset_configs) +end diff --git a/src/services/perpetual/cairo/definitions/objects.cairo b/src/services/perpetual/cairo/definitions/objects.cairo new file mode 100644 index 0000000..613e305 --- /dev/null +++ b/src/services/perpetual/cairo/definitions/objects.cairo @@ -0,0 +1,61 @@ +from services.perpetual.cairo.definitions.constants import FUNDING_INDEX_LOWER_BOUND +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.serialize import serialize_array, serialize_word + +struct FundingIndex: + member asset_id : felt + # funding_index in fxp 32.32 format. 
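+    # E.g. a funding index of 2.5 is stored as 2.5 * 2**32 = 0x280000000; negative indices are
+    # allowed down to FUNDING_INDEX_LOWER_BOUND.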
+ member funding_index : felt +end + +func funding_index_serialize{output_ptr : felt*}(funding_index : FundingIndex*): + serialize_word(funding_index.asset_id) + serialize_word(funding_index.funding_index - FUNDING_INDEX_LOWER_BOUND) + return () +end + +# Funding indices and their timestamp. +struct FundingIndicesInfo: + member n_funding_indices : felt + member funding_indices : FundingIndex* + member funding_timestamp : felt +end + +func funding_indices_info_serialize{output_ptr : felt*}(funding_indices : FundingIndicesInfo*): + get_fp_and_pc() + let __pc__ = [fp + 1] + + ret_pc_label: + serialize_array( + array=cast(funding_indices.funding_indices, felt*), + n_elms=funding_indices.n_funding_indices, + elm_size=FundingIndex.SIZE, + callback=funding_index_serialize + __pc__ - ret_pc_label) + serialize_word(funding_indices.funding_timestamp) + return () +end + +# Represents a single asset's Oracle Price in internal representation (Refer to the documentation of +# AssetOraclePrice for the definition of internal representation). +struct OraclePrice: + member asset_id : felt + # 32.32 fixed point. + member price : felt +end + +# An array of oracle prices. +struct OraclePrices: + member len : felt + member data : OraclePrice* +end + +func oracle_prices_new(len, data : OraclePrice*) -> (oracle_prices : OraclePrices*): + let (fp_val, pc_val) = get_fp_and_pc() + return (oracle_prices=cast(fp_val - 2 - OraclePrices.SIZE, OraclePrices*)) +end + +func oracle_price_serialize{output_ptr : felt*}(oracle_price : OraclePrice*): + serialize_word(oracle_price.asset_id) + serialize_word(oracle_price.price) + return () +end diff --git a/src/services/perpetual/cairo/definitions/perpetual_error_code.cairo b/src/services/perpetual/cairo/definitions/perpetual_error_code.cairo new file mode 100644 index 0000000..344f3ac --- /dev/null +++ b/src/services/perpetual/cairo/definitions/perpetual_error_code.cairo @@ -0,0 +1,44 @@ +# An enum that lists all possible errors that could happen during the run. Useful for giving context +# upon error via hints. +# The error code for success is zero and all other error codes are positive. +namespace PerpetualErrorCode: + const SUCCESS = 0 + # An error code for unspecified errors. + const ERROR = 1 + const ILLEGAL_POSITION_TRANSITION_ENLARGING_SYNTHETIC_HOLDINGS = 2 + const ILLEGAL_POSITION_TRANSITION_NO_RISK_REDUCED_VALUE = 3 + const ILLEGAL_POSITION_TRANSITION_REDUCING_TOTAL_VALUE_RISK_RATIO = 4 + const INVALID_ASSET_ORACLE_PRICE = 5 + const INVALID_COLLATERAL_ASSET_ID = 6 + const INVALID_FULFILLMENT_ASSETS_RATIO = 7 + const INVALID_FULFILLMENT_FEE_RATIO = 8 + const INVALID_FULFILLMENT_INFO = 9 + const INVALID_FUNDING_TICK_TIMESTAMP = 10 + const INVALID_PUBLIC_KEY = 11 + const INVALID_SIGNATURE = 12 + const MISSING_GLOBAL_FUNDING_INDEX = 13 + const MISSING_ORACLE_PRICE = 14 + const MISSING_SYNTHETIC_ASSET_ID = 15 + const OUT_OF_RANGE_AMOUNT = 16 + const OUT_OF_RANGE_BALANCE = 17 + const OUT_OF_RANGE_FUNDING_INDEX = 18 + const OUT_OF_RANGE_POSITIVE_AMOUNT = 19 + const OUT_OF_RANGE_TOTAL_RISK = 20 + const OUT_OF_RANGE_TOTAL_VALUE = 21 + const SAME_POSITION_ID = 22 + const TOO_MANY_SYNTHETIC_ASSETS_IN_POSITION = 23 + const TOO_MANY_SYNTHETIC_ASSETS_IN_SYSTEM = 24 + const UNDELEVERAGABLE_POSITION = 25 + const UNFAIR_DELEVERAGE = 26 + const UNLIQUIDATABLE_POSITION = 27 + const UNSORTED_ORACLE_PRICES = 28 +end + +# Receives an error code and verifies it is equal to SUCCESS. +# If not, the function will put the error code in a hint variable before exiting. 
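+# Usage sketch (hypothetical caller): after a validation step returns `error_code`, calling
+# `assert_success(error_code)` lets a failing run expose the PerpetualErrorCode value through the
+# `error_code` hint variable before the assert aborts the execution.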
+func assert_success(error_code): + %{ error_code = ids.error_code %} + assert error_code = PerpetualErrorCode.SUCCESS + %{ del error_code %} + return () +end diff --git a/src/services/perpetual/cairo/execute_batch.cairo b/src/services/perpetual/cairo/execute_batch.cairo new file mode 100644 index 0000000..a45d7dd --- /dev/null +++ b/src/services/perpetual/cairo/execute_batch.cairo @@ -0,0 +1,363 @@ +from services.perpetual.cairo.definitions.general_config import SyntheticAssetInfo +from services.perpetual.cairo.definitions.objects import OraclePrice +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.execute_batch_utils import ( + validate_funding_indices_in_general_config, validate_general_config) +from services.perpetual.cairo.oracle.oracle_price import ( + TimeBounds, check_oracle_prices, signed_prices_to_prices) +from services.perpetual.cairo.output.program_input import ProgramInput +from services.perpetual.cairo.output.program_output import PerpetualOutputs, perpetual_outputs_empty +from services.perpetual.cairo.state.state import CarriedState +from services.perpetual.cairo.transactions.batch_config import BatchConfig, batch_config_new +from services.perpetual.cairo.transactions.conditional_transfer import ( + ConditionalTransfer, execute_conditional_transfer) +from services.perpetual.cairo.transactions.deleverage import Deleverage, execute_deleverage +from services.perpetual.cairo.transactions.deposit import Deposit, execute_deposit +from services.perpetual.cairo.transactions.forced_trade import ForcedTrade, execute_forced_trade +from services.perpetual.cairo.transactions.forced_withdrawal import ( + ForcedWithdrawal, execute_forced_withdrawal) +from services.perpetual.cairo.transactions.funding_tick import FundingTick, execute_funding_tick +from services.perpetual.cairo.transactions.liquidate import Liquidate, execute_liquidate +from services.perpetual.cairo.transactions.oracle_prices_tick import ( + OraclePricesTick, execute_oracle_prices_tick) +from services.perpetual.cairo.transactions.trade import Trade, execute_trade +from services.perpetual.cairo.transactions.transaction import ( + Transaction, Transactions, TransactionType) +from services.perpetual.cairo.transactions.transfer import Transfer, execute_transfer +from services.perpetual.cairo.transactions.withdrawal import Withdrawal, execute_withdrawal +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math import assert_le, assert_lt +from starkware.cairo.common.registers import get_fp_and_pc + +func execute_transaction( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*, batch_config : BatchConfig*, + tx : Transaction*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + local tx_type = tx.tx_type + alloc_locals + + if tx_type == TransactionType.ORACLE_PRICES_TICK: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, + outputs) = execute_oracle_prices_tick( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, OraclePricesTick*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + 
ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.FUNDING_TICK: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, + outputs) = execute_funding_tick( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, FundingTick*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + # For every other transaction, we need to check that the funding timestamp is up to date + # with respect to the system time. + %{ error_code = ids.PerpetualErrorCode.INVALID_FUNDING_TICK_TIMESTAMP %} + assert_le{range_check_ptr=range_check_ptr}( + carried_state.system_time, + carried_state.global_funding_indices.funding_timestamp + + batch_config.general_config.timestamp_validation_config.funding_validity_period) + %{ del error_code %} + + if tx_type == TransactionType.TRADE: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, outputs) = execute_trade( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, Trade*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.DEPOSIT: + # Deposit. + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, outputs) = execute_deposit( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, Deposit*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.TRANSFER: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, outputs) = execute_transfer( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, Transfer*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.CONDITIONAL_TRANSFER: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, + outputs) = execute_conditional_transfer( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, ConditionalTransfer*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.LIQUIDATE: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, outputs) = execute_liquidate( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, Liquidate*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.DELEVERAGE: + let 
(pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, outputs) = execute_deleverage( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, Deleverage*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.WITHDRAWAL: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, outputs) = execute_withdrawal( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, Withdrawal*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.FORCED_WITHDRAWAL: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, + outputs) = execute_forced_withdrawal( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, ForcedWithdrawal*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + if tx_type == TransactionType.FORCED_TRADE: + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, + outputs) = execute_forced_trade( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + outputs=outputs, + tx=cast(tx.tx, ForcedTrade*)) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + assert 1 = 0 + jmp rel 0 +end + +func execute_batch_transactions( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*, batch_config : BatchConfig*, + n_txs : felt, tx : Transaction*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + if n_txs == 0: + # No transactions left. 
+ return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) + end + + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state, outputs) = execute_transaction( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs, + batch_config=batch_config, + tx=tx) + + return execute_batch_transactions( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs, + batch_config=batch_config, + n_txs=n_txs - 1, + tx=tx + Transaction.SIZE) +end + +func execute_batch( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, program_input : ProgramInput*, outputs : PerpetualOutputs*, + txs : Transactions*, end_system_time) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + let (local __fp__, _) = get_fp_and_pc() + + let (range_check_ptr) = validate_general_config( + range_check_ptr=range_check_ptr, general_config=program_input.general_config) + + # Time bound to check for oracle price signature timestamps. + local time_bounds : TimeBounds + assert time_bounds.min_time = ( + carried_state.system_time - + program_input.general_config.timestamp_validation_config.price_validity_period) + assert time_bounds.max_time = end_system_time + + # Validate minimal and maximal oracle price signatures. Refer to the documentation of + # OraclePricesTick for more details. + let (range_check_ptr, ecdsa_ptr, pedersen_ptr) = check_oracle_prices( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=pedersen_ptr, + n_oracle_prices=program_input.n_signed_oracle_prices, + asset_oracle_prices=program_input.signed_min_oracle_prices, + time_bounds=&time_bounds, + general_config=program_input.general_config) + let (local range_check_ptr, local ecdsa_ptr, local pedersen_ptr) = check_oracle_prices( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=pedersen_ptr, + n_oracle_prices=program_input.n_signed_oracle_prices, + asset_oracle_prices=program_input.signed_max_oracle_prices, + time_bounds=&time_bounds, + general_config=program_input.general_config) + + # Convert Signed prices to prices. + let (local signed_min_oracle_prices) = signed_prices_to_prices( + n_oracle_prices=program_input.n_signed_oracle_prices, + asset_oracle_prices=program_input.signed_min_oracle_prices) + let (signed_max_oracle_prices) = signed_prices_to_prices( + n_oracle_prices=program_input.n_signed_oracle_prices, + asset_oracle_prices=program_input.signed_max_oracle_prices) + + # Create BatchConfig. + let (batch_config : BatchConfig*) = batch_config_new( + general_config=program_input.general_config, + signed_min_oracle_prices=signed_min_oracle_prices, + signed_max_oracle_prices=signed_max_oracle_prices, + n_oracle_prices=program_input.n_signed_oracle_prices, + min_expiration_timestamp=program_input.minimum_expiration_timestamp) + + # Execute all txs. + let (local pedersen_ptr, local range_check_ptr, local ecdsa_ptr, local carried_state, + local outputs) = execute_batch_transactions( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs, + batch_config=batch_config, + n_txs=txs.len, + tx=txs.data) + + # Post batch validations. 
+    validate_funding_indices_in_general_config(
+        global_funding_indices=carried_state.global_funding_indices,
+        general_config=program_input.general_config)
+
+    assert carried_state.system_time = end_system_time
+
+    return (
+        pedersen_ptr=pedersen_ptr,
+        range_check_ptr=range_check_ptr,
+        ecdsa_ptr=ecdsa_ptr,
+        carried_state=carried_state,
+        outputs=outputs)
+end
diff --git a/src/services/perpetual/cairo/execute_batch_utils.cairo b/src/services/perpetual/cairo/execute_batch_utils.cairo
new file mode 100644
index 0000000..303d19b
--- /dev/null
+++ b/src/services/perpetual/cairo/execute_batch_utils.cairo
@@ -0,0 +1,111 @@
+from services.perpetual.cairo.definitions.constants import (
+    ASSET_ID_UPPER_BOUND, ASSET_RESOLUTION_LOWER_BOUND, ASSET_RESOLUTION_UPPER_BOUND,
+    COLLATERAL_ASSET_ID_UPPER_BOUND, N_ASSETS_UPPER_BOUND, ORACLE_PRICE_QUORUM_LOWER_BOUND,
+    ORACLE_PRICE_QUORUM_UPPER_BOUND, RISK_FACTOR_LOWER_BOUND, RISK_FACTOR_UPPER_BOUND)
+from services.perpetual.cairo.definitions.general_config import (
+    CollateralAssetInfo, GeneralConfig, SyntheticAssetInfo)
+from services.perpetual.cairo.definitions.objects import FundingIndicesInfo
+from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode
+from services.perpetual.cairo.position.funding import FundingIndex
+from starkware.cairo.common.math import (
+    assert_in_range, assert_le, assert_le_felt, assert_lt, assert_not_zero)
+
+func validate_funding_indices_in_general_config_inner(
+        funding_index : FundingIndex*, n_funding_indices,
+        synthetic_asset_info : SyntheticAssetInfo*, n_synthetic_assets_info):
+    if n_funding_indices == 0:
+        return ()
+    end
+    assert_not_zero(n_synthetic_assets_info)
+    if funding_index.asset_id == synthetic_asset_info.asset_id:
+        # Found synthetic asset info.
+        validate_funding_indices_in_general_config_inner(
+            funding_index=funding_index + FundingIndex.SIZE,
+            n_funding_indices=n_funding_indices - 1,
+            synthetic_asset_info=synthetic_asset_info + SyntheticAssetInfo.SIZE,
+            n_synthetic_assets_info=n_synthetic_assets_info - 1)
+        return ()
+    else:
+        # Skip synthetic asset info.
+        validate_funding_indices_in_general_config_inner(
+            funding_index=funding_index,
+            n_funding_indices=n_funding_indices,
+            synthetic_asset_info=synthetic_asset_info + SyntheticAssetInfo.SIZE,
+            n_synthetic_assets_info=n_synthetic_assets_info - 1)
+        return ()
+    end
+end
+
+# Validates that every asset id in global_funding_indices is in general_config's
+# synthetic asset info.
+func validate_funding_indices_in_general_config( + global_funding_indices : FundingIndicesInfo*, general_config : GeneralConfig*): + validate_funding_indices_in_general_config_inner( + funding_index=global_funding_indices.funding_indices, + n_funding_indices=global_funding_indices.n_funding_indices, + synthetic_asset_info=general_config.synthetic_assets_info, + n_synthetic_assets_info=general_config.n_synthetic_assets_info) + return () +end + +func validate_assets_config_inner( + range_check_ptr, synthetic_assets_info_ptr : SyntheticAssetInfo*, n_synthetic_assets_info, + prev_asset_id) -> (range_check_ptr): + if n_synthetic_assets_info == 0: + assert_lt{range_check_ptr=range_check_ptr}(prev_asset_id, ASSET_ID_UPPER_BOUND) + return (range_check_ptr=range_check_ptr) + end + assert_lt{range_check_ptr=range_check_ptr}(prev_asset_id, synthetic_assets_info_ptr.asset_id) + + assert_in_range{range_check_ptr=range_check_ptr}( + synthetic_assets_info_ptr.risk_factor, RISK_FACTOR_LOWER_BOUND, RISK_FACTOR_UPPER_BOUND) + + assert_in_range{range_check_ptr=range_check_ptr}( + synthetic_assets_info_ptr.oracle_price_quorum, + ORACLE_PRICE_QUORUM_LOWER_BOUND, + ORACLE_PRICE_QUORUM_UPPER_BOUND) + + assert_in_range{range_check_ptr=range_check_ptr}( + synthetic_assets_info_ptr.resolution, + ASSET_RESOLUTION_LOWER_BOUND, + ASSET_RESOLUTION_UPPER_BOUND) + + return validate_assets_config_inner( + range_check_ptr=range_check_ptr, + synthetic_assets_info_ptr=synthetic_assets_info_ptr + SyntheticAssetInfo.SIZE, + n_synthetic_assets_info=n_synthetic_assets_info - 1, + prev_asset_id=synthetic_assets_info_ptr.asset_id) +end + +# Validates that the synthetic assets info in general_config is sorted according to asset_id and +# that their risk factor is in range. +func validate_assets_config(range_check_ptr, general_config : GeneralConfig*) -> (range_check_ptr): + return validate_assets_config_inner( + range_check_ptr=range_check_ptr, + synthetic_assets_info_ptr=general_config.synthetic_assets_info, + n_synthetic_assets_info=general_config.n_synthetic_assets_info, + prev_asset_id=-1) +end + +# Validates that all the fields in general config are in range and that the synthetic assets info is +# sorted according to asset_id. 
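# Illustrative sketch, not part of the diff: the checks performed by validate_assets_config and
# validate_general_config above, restated in Python. The bound values below are placeholders; the
# real ones are defined in services.perpetual.cairo.definitions.constants.

ASSET_ID_UPPER_BOUND = 2**120            # placeholder
RISK_FACTOR_BOUNDS = (1, 2**32)          # placeholder (lower inclusive, upper exclusive)
ORACLE_PRICE_QUORUM_BOUNDS = (1, 2**32)  # placeholder
ASSET_RESOLUTION_BOUNDS = (1, 2**64)     # placeholder

def validate_assets_config(synthetic_assets_info):
    prev_asset_id = -1
    for info in synthetic_assets_info:
        # Strictly increasing asset ids: the list is sorted and has no duplicates.
        assert prev_asset_id < info['asset_id']
        assert RISK_FACTOR_BOUNDS[0] <= info['risk_factor'] < RISK_FACTOR_BOUNDS[1]
        assert ORACLE_PRICE_QUORUM_BOUNDS[0] <= info['oracle_price_quorum'] < ORACLE_PRICE_QUORUM_BOUNDS[1]
        assert ASSET_RESOLUTION_BOUNDS[0] <= info['resolution'] < ASSET_RESOLUTION_BOUNDS[1]
        prev_asset_id = info['asset_id']
    assert prev_asset_id < ASSET_ID_UPPER_BOUND

validate_assets_config([
    {'asset_id': 10, 'risk_factor': 2, 'oracle_price_quorum': 3, 'resolution': 10**6},
    {'asset_id': 11, 'risk_factor': 5, 'oracle_price_quorum': 3, 'resolution': 10**7},
])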
+func validate_general_config(range_check_ptr, general_config : GeneralConfig*) -> (range_check_ptr): + let (range_check_ptr) = validate_assets_config( + range_check_ptr=range_check_ptr, general_config=general_config) + + tempvar collateral_asset_info : CollateralAssetInfo* = general_config.collateral_asset_info + + assert_le_felt{range_check_ptr=range_check_ptr}( + collateral_asset_info.asset_id, COLLATERAL_ASSET_ID_UPPER_BOUND - 1) + + assert_in_range{range_check_ptr=range_check_ptr}( + collateral_asset_info.resolution, + ASSET_RESOLUTION_LOWER_BOUND, + ASSET_RESOLUTION_UPPER_BOUND) + + %{ error_code = ids.PerpetualErrorCode.TOO_MANY_SYNTHETIC_ASSETS_IN_SYSTEM %} + assert_le{range_check_ptr=range_check_ptr}( + general_config.n_synthetic_assets_info, N_ASSETS_UPPER_BOUND) + %{ del error_code %} + return (range_check_ptr=range_check_ptr) +end diff --git a/src/services/perpetual/cairo/main.cairo b/src/services/perpetual/cairo/main.cairo new file mode 100644 index 0000000..f007dcf --- /dev/null +++ b/src/services/perpetual/cairo/main.cairo @@ -0,0 +1,160 @@ +%builtins output pedersen range_check ecdsa + +from services.perpetual.cairo.definitions.general_config_hash import ( + general_config_hash, general_config_hash_synthetic_assets) +from services.perpetual.cairo.execute_batch import execute_batch +from services.perpetual.cairo.output.data_availability import output_availability_data +from services.perpetual.cairo.output.forced import ForcedAction +from services.perpetual.cairo.output.program_input import ProgramInput +from services.perpetual.cairo.output.program_output import ( + Modification, PerpetualOutputs, ProgramOutput, perpetual_outputs_empty, program_output_new, + program_output_serialize) +from services.perpetual.cairo.state.state import ( + CarriedState, SharedState, SquashedCarriedState, carried_state_squash, + shared_state_apply_state_updates, shared_state_serialize, shared_state_to_carried_state) +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.registers import get_fp_and_pc + +# Hint argument: +# program_input - An object that has the following fields: +# program_input_struct - The fields of ProgramInput. +# positions_dict - A dictionary from position id to position. +# orders_dict - A dictionary from order id to order state. +# max_n_words_per_memory_page - Amount of words that can fit in a memory page. +# merkle_facts - A dictionary from the hash value of a merkle node to the pair of children values. +func main( + output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, + ecdsa_ptr : SignatureBuiltin*) -> ( + output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, + ecdsa_ptr : SignatureBuiltin*): + alloc_locals + let (local __fp__, _) = get_fp_and_pc() + local program_input : ProgramInput + + %{ + # Initialize program input and hint variables. 
+ segments.write_arg(ids.program_input.address_, program_input['program_input_struct']) + positions_dict = {int(x): y for x,y in program_input['positions_dict'].items()} + orders_dict = {int(x): y for x,y in program_input['orders_dict'].items()} + max_n_words_per_memory_page = program_input['max_n_words_per_memory_page'] + + def as_int(x): + return int(x, 16) + preimage = { + as_int(root): (as_int(left_child), as_int(right_child)) + for root, (left_child, right_child) in program_input['merkle_facts'].items() + } + %} + let (local initial_carried_state) = shared_state_to_carried_state( + program_input.prev_shared_state) + + # Execute batch. + let (local outputs_start) = perpetual_outputs_empty() + let (local pedersen_ptr, range_check_ptr, local ecdsa_ptr, carried_state : CarriedState*, + local outputs) = execute_batch( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=initial_carried_state, + program_input=&program_input, + outputs=outputs_start, + txs=program_input.txs, + end_system_time=program_input.new_shared_state.system_time) + + # Get updated shared state. + with range_check_ptr: + let (squashed_carried_state) = carried_state_squash( + initial_carried_state=initial_carried_state, carried_state=carried_state) + end + local range_check_ptr = range_check_ptr + local squashed_carried_state : SquashedCarriedState* = squashed_carried_state + + let positions_root = program_input.new_shared_state.positions_root + let orders_root = program_input.new_shared_state.orders_root + %{ + new_positions_root = ids.positions_root + new_orders_root = ids.orders_root + %} + let (pedersen_ptr, local new_shared_state) = shared_state_apply_state_updates( + hash_ptr=pedersen_ptr, + shared_state=program_input.prev_shared_state, + squashed_carried_state=squashed_carried_state, + general_config=program_input.general_config) + + # Write public output. + with pedersen_ptr: + let (n_asset_configs, asset_configs) = general_config_hash_synthetic_assets( + general_config_ptr=program_input.general_config) + let (general_config_hash_value) = general_config_hash( + general_config_ptr=program_input.general_config) + end + local pedersen_ptr : HashBuiltin* = pedersen_ptr + let (program_output : ProgramOutput*) = program_output_new( + general_config_hash=general_config_hash_value, + n_asset_configs=n_asset_configs, + asset_configs=asset_configs, + prev_shared_state=program_input.prev_shared_state, + new_shared_state=new_shared_state, + minimum_expiration_timestamp=program_input.minimum_expiration_timestamp, + n_modifications=( + outputs.modifications_ptr - outputs_start.modifications_ptr) / Modification.SIZE, + modifications=outputs_start.modifications_ptr, + n_forced_actions=( + outputs.forced_actions_ptr - outputs_start.forced_actions_ptr) / ForcedAction.SIZE, + forced_actions=outputs_start.forced_actions_ptr, + n_conditions=outputs.conditions_ptr - outputs_start.conditions_ptr, + conditions=outputs_start.conditions_ptr) + + with output_ptr: + program_output_serialize(program_output=program_output) + end + + %{ onchain_data_start = ids.output_ptr %} + + let (range_check_ptr, output_ptr) = output_availability_data( + range_check_ptr=range_check_ptr, + output_ptr=output_ptr, + squashed_state=squashed_carried_state, + perpetual_outputs_start=outputs_start, + perpetual_outputs_end=outputs) + + %{ + from starkware.python.math_utils import div_ceil + onchain_data_size = ids.output_ptr - onchain_data_start + assert onchain_data_size > 0, 'Empty onchain data is not supported.' 
+ + # Split the output into pages. + n_pages = div_ceil(onchain_data_size, max_n_words_per_memory_page) + for i in range(n_pages): + start_offset = i * max_n_words_per_memory_page + output_builtin.add_page( + page_id=1 + i, + page_start=onchain_data_start + start_offset, + page_size=min(onchain_data_size - start_offset, max_n_words_per_memory_page), + ) + + # Set the tree structure to a root with two children: + # * A leaf which represents the main part + # * An inner node for the onchain data part (which contains n_pages children). + # + # This is encoded using the following sequence: + output_builtin.add_attribute('gps_fact_topology', [ + # Push 1 + n_pages pages (all of the pages). + 1 + n_pages, + # Create a parent node for the last n_pages. + n_pages, + # Don't push additional pages. + 0, + # Take the first page (the main part) and the node that was created (onchain data) + # and use them to construct the root of the fact tree. + 2, + ]) + %} + + return ( + output_ptr=output_ptr, + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr) +end diff --git a/src/services/perpetual/cairo/oracle/oracle_price.cairo b/src/services/perpetual/cairo/oracle/oracle_price.cairo new file mode 100644 index 0000000..52b51dc --- /dev/null +++ b/src/services/perpetual/cairo/oracle/oracle_price.cairo @@ -0,0 +1,329 @@ +from services.perpetual.cairo.definitions.constants import ( + EXTERNAL_PRICE_FIXED_POINT_UNIT, EXTERNAL_PRICE_UPPER_BOUND, FXP_32_ONE, PRICE_LOWER_BOUND, + PRICE_UPPER_BOUND) +from services.perpetual.cairo.definitions.general_config import ( + CollateralAssetInfo, GeneralConfig, SyntheticAssetInfo) +from services.perpetual.cairo.definitions.objects import OraclePrice +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.find_element import find_element +from starkware.cairo.common.hash import hash2 +from starkware.cairo.common.math import ( + assert_in_range, assert_le, assert_lt_felt, assert_nn_le, assert_not_zero, sign, + unsigned_div_rem) +from starkware.cairo.common.signature import verify_ecdsa_signature + +# Price definitions: +# An external price is a unit of the collateral asset divided by a unit of synthetic asset. +# An internal price is computed as the ratio between a unit of collateral asset and its resolution, +# divided by the ratio between a unit of synthetic asset and its resolution: +# (collateral_asset_unit / collateral_resolution) / (synthetic_asset_unit / synthetic_resolution). + +# Represents a single signature on an external price with a timestamp. +struct SignedOraclePrice: + member signer_key : felt + member external_price : felt + member timestamp : felt + member signed_asset_id : felt + member signature_r : felt + member signature_s : felt +end + +# Represents a single Oracle Price of an asset in internal representation and +# signatures on that price. The price is a median of all prices in the signatures. +struct AssetOraclePrice: + member asset_id : felt + member price : felt + member n_signed_prices : felt + # Oracle signatures, sorted by signer_key. + member signed_prices : SignedOraclePrice* +end + +struct TimeBounds: + member min_time : felt + member max_time : felt +end + +const TIMESTAMP_BOUND = %[2**32%] + +# Checks a single price signature. +# * Signature is valid. +# * Signer public key is present in the SyntheticAssetInfo. 
+# * Signer asset id is present in the SyntheticAssetInfo. +# * Valid timestamp. +# * Signer key is greater than the last signer key (for uniqueness). +# Returns (is_le, is_ge) with respect to the median price. This is needed to verify the median. +func check_price_signature( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*, + time_bounds : TimeBounds*, asset_info : SyntheticAssetInfo*, median_price, + collateral_resolution, sig : SignedOraclePrice*) -> ( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*, is_le, is_ge): + alloc_locals + + # Check ranges. + assert_nn_le{range_check_ptr=range_check_ptr}(sig.external_price, EXTERNAL_PRICE_UPPER_BOUND) + assert_nn_le{range_check_ptr=range_check_ptr}(sig.timestamp, TIMESTAMP_BOUND) + + # Compute message. + with hash_ptr: + let (message) = hash2( + x=sig.signed_asset_id, y=sig.external_price * TIMESTAMP_BOUND + sig.timestamp) + end + local hash_ptr : HashBuiltin* = hash_ptr + + # Check signature. + with ecdsa_ptr: + verify_ecdsa_signature( + message=message, + public_key=sig.signer_key, + signature_r=sig.signature_r, + signature_s=sig.signature_s) + end + local ecdsa_ptr : SignatureBuiltin* = ecdsa_ptr + + # Check that signer is in the config. + %{ error_code = ids.PerpetualErrorCode.INVALID_ASSET_ORACLE_PRICE %} + find_element{range_check_ptr=range_check_ptr}( + array_ptr=asset_info.oracle_price_signers, + elm_size=1, + n_elms=asset_info.n_oracle_price_signers, + key=sig.signer_key) + + %{ del error_code %} + # Check that signed_asset_id is in the config. + find_element{range_check_ptr=range_check_ptr}( + array_ptr=asset_info.oracle_price_signed_asset_ids, + elm_size=1, + n_elms=asset_info.n_oracle_price_signed_asset_ids, + key=sig.signed_asset_id) + + # Check timestamp. + assert_in_range{range_check_ptr=range_check_ptr}( + sig.timestamp, time_bounds.min_time, time_bounds.max_time + 1) + + # Transform to internal price. + # price is a 32.32 bit fixed point number in internal asset units. + # signed prices are fixed point in external asset units. + # external_price_repr = external_coll / external_synth * EXTERNAL_PRICE_FIXED_POINT_UNIT + # internal_price_repr = internal_coll / internal_synth * FXP_32_ONE + # = (external_coll * res_coll) / (external_synth * res_synth) * FXP_32_ONE + # = external_price_repr * res_coll * FXP_32_ONE / + # (res_synth * EXTERNAL_PRICE_FIXED_POINT_UNIT). + # Assuming resolutions are 64bit. + # numerator is 192 bit. + let numerator = sig.external_price * collateral_resolution * FXP_32_ONE + # denominator is 96 bit. + tempvar denominator = asset_info.resolution * EXTERNAL_PRICE_FIXED_POINT_UNIT + # Add denominator/2 to round. + let (internal_price, _) = unsigned_div_rem{range_check_ptr=range_check_ptr}( + numerator + denominator / 2, denominator) + + # Check above or below median. + let (median_comparison) = sign{range_check_ptr=range_check_ptr}( + value=median_price - internal_price) + + if median_comparison == 0: + return ( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + is_le=1, + is_ge=1) + end + # If median_comparison is 1, is_ge will be 1. If median_comparison is -1, is_ge will be 0. 
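# Illustrative sketch, not part of the diff: the external-to-internal price conversion performed a
# few lines above, in plain Python. FXP_32_ONE is 2**32; EXTERNAL_PRICE_FIXED_POINT_UNIT and the
# resolutions used in the call below are example values only, not the real configuration.

FXP_32_ONE = 2**32
EXTERNAL_PRICE_FIXED_POINT_UNIT = 10**18  # example value

def to_internal_price(external_price, collateral_resolution, synthetic_resolution):
    numerator = external_price * collateral_resolution * FXP_32_ONE
    denominator = synthetic_resolution * EXTERNAL_PRICE_FIXED_POINT_UNIT
    # Adding denominator // 2 before dividing rounds to the nearest integer instead of down.
    return (numerator + denominator // 2) // denominator

# An external price of 2.5 collateral units per synthetic unit, with example resolutions:
internal = to_internal_price(25 * 10**17, collateral_resolution=10**6, synthetic_resolution=10**10)
print(internal)  # ~2.5 * (10**6 / 10**10) * 2**32, i.e. about 1073742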
+ tempvar is_ge = (median_comparison + 1) / 2 + return ( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + is_le=1 - is_ge, + is_ge=is_ge) +end + +func check_oracle_price_inner( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*, + time_bounds : TimeBounds*, asset_info : SyntheticAssetInfo*, median_price, + collateral_resolution, sig : SignedOraclePrice*, n_sigs, last_signer, n_le, n_ge) -> ( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*, n_le, n_ge): + if n_sigs == 0: + # All signatures are checked. + return ( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + n_le=n_le, + n_ge=n_ge) + end + + # Check that signer_key is greater than the last signer key. This assures uniqueness of signers. + assert_lt_felt{range_check_ptr=range_check_ptr}(last_signer, sig.signer_key) + + # Check the signature. + let (range_check_ptr, ecdsa_ptr, hash_ptr, is_le, is_ge) = check_price_signature( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + time_bounds=time_bounds, + asset_info=asset_info, + median_price=median_price, + collateral_resolution=collateral_resolution, + sig=sig) + + # Recursive call. + return check_oracle_price_inner( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + time_bounds=time_bounds, + asset_info=asset_info, + median_price=median_price, + collateral_resolution=collateral_resolution, + sig=sig + SignedOraclePrice.SIZE, + n_sigs=n_sigs - 1, + last_signer=sig.signer_key, + n_le=n_le + is_le, + n_ge=n_ge + is_ge) +end + +# Checks the validity of a single oracle price given as AssetOraclePrice. +# Checks there are at least quorum valid signatures from distinct signer keys, and that the price +# used is a median price of these signed prices. +func check_oracle_price( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*, + time_bounds : TimeBounds*, asset_oracle_price : AssetOraclePrice*, + asset_info : SyntheticAssetInfo*, collateral_info : CollateralAssetInfo*) -> ( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*): + alloc_locals + local n_sigs = asset_oracle_price.n_signed_prices + + # Check that we have enough signatures (>= quorum). + assert_le{range_check_ptr=range_check_ptr}(asset_info.oracle_price_quorum, n_sigs) + + # Check that price is in range. + assert_in_range{range_check_ptr=range_check_ptr}( + asset_oracle_price.price, PRICE_LOWER_BOUND, PRICE_UPPER_BOUND) + + # Check all signatures. + let (range_check_ptr, ecdsa_ptr, hash_ptr, n_le, n_ge) = check_oracle_price_inner( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + time_bounds=time_bounds, + asset_info=asset_info, + median_price=asset_oracle_price.price, + collateral_resolution=collateral_info.resolution, + sig=asset_oracle_price.signed_prices, + n_sigs=n_sigs, + last_signer=0, + n_le=0, + n_ge=0) + + # Check that the median price is indeed a median: + # At least half the oracle prices are greater or equal to the median price and + # at least half the oracle prices are smaller or equal to the median price. 
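# Illustrative sketch, not part of the diff: why the two count checks below are enough to accept
# the reported price as a median of the signed prices. The Cairo code accumulates analogous
# counters (n_le, n_ge) while verifying the signatures.

def is_acceptable_median(candidate, signed_prices):
    n = len(signed_prices)
    n_at_most = sum(1 for p in signed_prices if p <= candidate)
    n_at_least = sum(1 for p in signed_prices if p >= candidate)
    # Same shape as the two assert_le calls below: n_sigs <= 2 * n_le and n_sigs <= 2 * n_ge.
    return n <= 2 * n_at_most and n <= 2 * n_at_least

assert is_acceptable_median(3, [1, 2, 3, 4, 5])      # The true median passes.
assert not is_acceptable_median(4, [1, 2, 3, 4, 5])  # Fails: only 2 of the 5 prices are >= 4.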
+ assert_le{range_check_ptr=range_check_ptr}(n_sigs, n_le * 2) + assert_le{range_check_ptr=range_check_ptr}(n_sigs, n_ge * 2) + + return (range_check_ptr=range_check_ptr, ecdsa_ptr=ecdsa_ptr, hash_ptr=hash_ptr) +end + +func check_oracle_prices_inner( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*, n_oracle_prices, + asset_oracle_prices : AssetOraclePrice*, n_synthetic_assets_info, + synthetic_assets_info : SyntheticAssetInfo*, time_bounds : TimeBounds*, + general_config : GeneralConfig*) -> ( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*): + if n_oracle_prices == 0: + # All prices are validated. + return (range_check_ptr=range_check_ptr, ecdsa_ptr=ecdsa_ptr, hash_ptr=hash_ptr) + end + + # n_synthetic_assets_info = 0 means that the current asset was not found in the general config. + %{ error_code = ids.PerpetualErrorCode.MISSING_SYNTHETIC_ASSET_ID %} + assert_not_zero(n_synthetic_assets_info) + %{ del error_code %} + + if asset_oracle_prices.asset_id != synthetic_assets_info.asset_id: + # Advance synthetic_assets_info until we get to our asset. + return check_oracle_prices_inner( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + n_oracle_prices=n_oracle_prices, + asset_oracle_prices=asset_oracle_prices, + n_synthetic_assets_info=n_synthetic_assets_info - 1, + synthetic_assets_info=synthetic_assets_info + SyntheticAssetInfo.SIZE, + time_bounds=time_bounds, + general_config=general_config) + end + + # Check this oracle price. + let (range_check_ptr, ecdsa_ptr, hash_ptr) = check_oracle_price( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + time_bounds=time_bounds, + asset_oracle_price=asset_oracle_prices, + asset_info=synthetic_assets_info, + collateral_info=general_config.collateral_asset_info) + + # Recursive call. + return check_oracle_prices_inner( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + n_oracle_prices=n_oracle_prices - 1, + asset_oracle_prices=asset_oracle_prices + AssetOraclePrice.SIZE, + n_synthetic_assets_info=n_synthetic_assets_info - 1, + synthetic_assets_info=synthetic_assets_info + SyntheticAssetInfo.SIZE, + time_bounds=time_bounds, + general_config=general_config) +end + +# Checks that a list of AssetOraclePrice instances are valid with respect to a GeneralConfig and a +# time frame. 
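# Illustrative sketch, not part of the diff: the merge-style scan used by check_oracle_prices_inner
# above. Both lists are sorted by asset_id, so one forward pass over the synthetic assets info is
# enough to match every oracle price with its asset configuration.

def match_prices_to_assets(asset_oracle_prices, synthetic_assets_info):
    matched = []
    i = 0
    for price in asset_oracle_prices:
        # Skip asset infos until we reach this price's asset; running out of them means the asset
        # id is missing from the general config (MISSING_SYNTHETIC_ASSET_ID).
        while i < len(synthetic_assets_info) and synthetic_assets_info[i]['asset_id'] != price['asset_id']:
            i += 1
        assert i < len(synthetic_assets_info), 'MISSING_SYNTHETIC_ASSET_ID'
        matched.append((price, synthetic_assets_info[i]))
        i += 1
    return matched

prices = [{'asset_id': 2, 'price': 100}, {'asset_id': 5, 'price': 7}]
assets = [{'asset_id': 1}, {'asset_id': 2}, {'asset_id': 5}, {'asset_id': 9}]
assert len(match_prices_to_assets(prices, assets)) == 2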
+func check_oracle_prices( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*, n_oracle_prices, + asset_oracle_prices : AssetOraclePrice*, time_bounds : TimeBounds*, + general_config : GeneralConfig*) -> ( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, hash_ptr : HashBuiltin*): + return check_oracle_prices_inner( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + hash_ptr=hash_ptr, + n_oracle_prices=n_oracle_prices, + asset_oracle_prices=asset_oracle_prices, + n_synthetic_assets_info=general_config.n_synthetic_assets_info, + synthetic_assets_info=general_config.synthetic_assets_info, + time_bounds=time_bounds, + general_config=general_config) +end + +func signed_prices_to_price_inner( + n_oracle_prices, asset_oracle_prices : AssetOraclePrice*, oracle_prices : OraclePrice*): + if n_oracle_prices == 0: + return () + end + + assert oracle_prices.asset_id = asset_oracle_prices.asset_id + assert oracle_prices.price = asset_oracle_prices.price + + signed_prices_to_price_inner( + n_oracle_prices=n_oracle_prices - 1, + asset_oracle_prices=asset_oracle_prices + AssetOraclePrice.SIZE, + oracle_prices=oracle_prices + OraclePrice.SIZE) + return () +end + +# Converts signed oracle prices (AssetOraclePrice*) to oracle prices (OraclePrice*). +func signed_prices_to_prices(n_oracle_prices, asset_oracle_prices : AssetOraclePrice*) -> ( + oracle_prices : OraclePrice*): + alloc_locals + let (local oracle_prices : OraclePrice*) = alloc() + signed_prices_to_price_inner( + n_oracle_prices=n_oracle_prices, + asset_oracle_prices=asset_oracle_prices, + oracle_prices=oracle_prices) + return (oracle_prices) +end diff --git a/src/services/perpetual/cairo/order/limit_order.cairo b/src/services/perpetual/cairo/order/limit_order.cairo new file mode 100644 index 0000000..c7160e4 --- /dev/null +++ b/src/services/perpetual/cairo/order/limit_order.cairo @@ -0,0 +1,67 @@ +from services.exchange.cairo.definitions.constants import VAULT_ID_UPPER_BOUND +from services.exchange.cairo.order import OrderBase +from services.exchange.cairo.signature_message_hashes import ExchangeLimitOrder +from services.exchange.cairo.signature_message_hashes import ( + limit_order_hash as exchange_limit_order_hash) +from services.perpetual.cairo.definitions.constants import POSITION_ID_UPPER_BOUND +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash import hash2 + +struct LimitOrder: + member base : OrderBase* + member amount_synthetic : felt + member amount_collateral : felt + member amount_fee : felt + member asset_id_synthetic : felt + member asset_id_collateral : felt + member position_id : felt + member is_buying_synthetic : felt +end + +# limit_order_hash: +# Computes the hash of a limit order. 
+# +# The hash is defined as h(h(h(h(w1, w2), w3), w4), w5) where h is the +# starkware pedersen function and w1,...w5 are as follows: +# w1= token_sell +# w2= token_buy +# w3= token_fee +# w4= amount_sell (64 bit) || amount_buy (64 bit) || amount_fee (64 bit) || nonce (32 bit) +# w5= 0x3 (10 bit) || vault_fee_src (64 bit) || vault_sell (64 bit) || vault_buy (64 bit) +# || expiration_timestamp (32 bit) || 0 (17 bit) +# +# Assumptions (bounds defined in services.perpetual.cairo.definitions.constants): +# amount_sell < AMOUNT_UPPER_BOUND +# amount_buy < AMOUNT_UPPER_BOUND +# amount_fee < AMOUNT_UPPER_BOUND +# nonce < NONCE_UPPER_BOUND +# position_id < POSITION_ID_UPPER_BOUND +# expiration_timestamp < EXPIRATION_TIMESTAMP_UPPER_BOUND. +func limit_order_hash{pedersen_ptr : HashBuiltin*}(limit_order : LimitOrder*) -> (limit_order_hash): + alloc_locals + static_assert POSITION_ID_UPPER_BOUND == VAULT_ID_UPPER_BOUND + + let (local exchange_limit_order : ExchangeLimitOrder*) = alloc() + assert exchange_limit_order.base = limit_order.base + assert exchange_limit_order.amount_fee = limit_order.amount_fee + assert exchange_limit_order.asset_id_fee = limit_order.asset_id_collateral + assert exchange_limit_order.vault_buy = limit_order.position_id + assert exchange_limit_order.vault_sell = limit_order.position_id + assert exchange_limit_order.vault_fee = limit_order.position_id + + if limit_order.is_buying_synthetic != 0: + assert exchange_limit_order.asset_id_sell = limit_order.asset_id_collateral + assert exchange_limit_order.asset_id_buy = limit_order.asset_id_synthetic + assert exchange_limit_order.amount_sell = limit_order.amount_collateral + assert exchange_limit_order.amount_buy = limit_order.amount_synthetic + else: + assert exchange_limit_order.asset_id_sell = limit_order.asset_id_synthetic + assert exchange_limit_order.asset_id_buy = limit_order.asset_id_collateral + assert exchange_limit_order.amount_sell = limit_order.amount_synthetic + assert exchange_limit_order.amount_buy = limit_order.amount_collateral + end + + let (limit_order_hash) = exchange_limit_order_hash(limit_order=exchange_limit_order) + return (limit_order_hash=limit_order_hash) +end diff --git a/src/services/perpetual/cairo/order/order.cairo b/src/services/perpetual/cairo/order/order.cairo new file mode 100644 index 0000000..2977971 --- /dev/null +++ b/src/services/perpetual/cairo/order/order.cairo @@ -0,0 +1,135 @@ +from services.exchange.cairo.order import OrderBase +from services.perpetual.cairo.definitions.constants import ( + AMOUNT_UPPER_BOUND, EXPIRATION_TIMESTAMP_UPPER_BOUND, NONCE_UPPER_BOUND, ORDER_ID_UPPER_BOUND, + POSITION_ID_UPPER_BOUND, POSITIVE_AMOUNT_LOWER_BOUND, RANGE_CHECK_BOUND, SIGNED_MESSAGE_BOUND) +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from starkware.cairo.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.common.dict import dict_update +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math import assert_in_range, assert_le, assert_nn, assert_nn_le +from starkware.cairo.common.signature import verify_ecdsa_signature + +# Extracts the order_id from the message_hash. +# The order_id is represented by the 64 most significant bits of the message_hash. +# +# Assumptions: +# The caller checks that 0 <= order_id < ORDER_ID_UPPER_BOUND. +# 0 <= message_hash < SIGNED_MESSAGE_BOUND. 
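# Illustrative sketch, not part of the diff: the packing that extract_order_id below unpacks. Per
# the field sizes given in its comment (64 + 59 + 128 = 251 bits), the order_id is simply the 64
# most significant bits of the message hash; the bound constants are restated here from those sizes.

SIGNED_MESSAGE_BOUND = 2**251
ORDER_ID_UPPER_BOUND = 2**64
RANGE_CHECK_BOUND = 2**128
ORDER_ID_SHIFT = SIGNED_MESSAGE_BOUND // ORDER_ID_UPPER_BOUND   # 2**187
MIDDLE_FIELD_BOUND = ORDER_ID_SHIFT // RANGE_CHECK_BOUND        # 2**59

def extract_order_id(message_hash):
    assert 0 <= message_hash < SIGNED_MESSAGE_BOUND
    order_id = message_hash // ORDER_ID_SHIFT
    middle_field = (message_hash // RANGE_CHECK_BOUND) % MIDDLE_FIELD_BOUND
    right_field = message_hash % RANGE_CHECK_BOUND
    # The decomposition that the Cairo code asserts and range-checks:
    assert message_hash == order_id * ORDER_ID_SHIFT + middle_field * RANGE_CHECK_BOUND + right_field
    return order_id

assert extract_order_id((42 << 187) + 12345) == 42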
+func extract_order_id(range_check_ptr, message_hash) -> (range_check_ptr, order_id): + # The 251-bit message_hash can be viewed as a packing of three fields: + # +----------------+--------------------+----------------LSB-+ + # | order_id (64b) | middle_field (59b) | right_field (128b) | + # +----------------+--------------------+--------------------+ + # . + const ORDER_ID_SHIFT = SIGNED_MESSAGE_BOUND / ORDER_ID_UPPER_BOUND + const MIDDLE_FIELD_BOUND = ORDER_ID_SHIFT / RANGE_CHECK_BOUND + + # Local variables. + local middle_field + local right_field + local order_id + %{ + msg_hash = ids.message_hash + ids.order_id = msg_hash // ids.ORDER_ID_SHIFT + ids.right_field = msg_hash & (ids.RANGE_CHECK_BOUND - 1) + ids.middle_field = (msg_hash // ids.RANGE_CHECK_BOUND) & (ids.MIDDLE_FIELD_BOUND - 1) + assert ids.MIDDLE_FIELD_BOUND & (ids.MIDDLE_FIELD_BOUND - 1) == 0, \ + f'MIDDLE_FIELD_BOUND should be a power of 2' + %} + alloc_locals + + # Verify that the message_hash definition holds, i.e., that: + assert message_hash = order_id * ORDER_ID_SHIFT + middle_field * RANGE_CHECK_BOUND + right_field + + # Verify the message_hash structure (i.e., the size of each field), to ensure unique unpacking. + # Note that the size of order_id is verified by performing merkle_update on the order tree. + # Check that 0 <= right_field < RANGE_CHECK_BOUND. + assert_nn{range_check_ptr=range_check_ptr}(right_field) + + # Check that 0 <= middle_field < MIDDLE_FIELD_BOUND. + assert_nn_le{range_check_ptr=range_check_ptr}(middle_field, MIDDLE_FIELD_BOUND - 1) + + return (range_check_ptr=range_check_ptr, order_id=order_id) +end + +# Updates the fulfillment amount of a user order to prevent replays. +# Extracts the order_id from the message_hash. +# Checks that update_amount does not exceed the order_capacity (= full_amount - fulfilled_amount). +# And updates the fulfilled_amount to reflect that 'update_amount' units were consumed. +# +# Checks that update_amount and full_amount are in the range [0, AMOUNT_UPPER_BOUND) +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# orders_dict - a pointer to the orders dict. +# message_hash - The hash of the order. +# update_amount - The amount to add to the current amount in the order tree. +# full_amount - The full in the user order, the order may not exceed this amount. +# +# Assumption: +# The amounts in the orders_dict are non-negative. +func update_order_fulfillment( + range_check_ptr, orders_dict : DictAccess*, message_hash, update_amount, full_amount) -> ( + range_check_ptr, orders_dict : DictAccess*): + alloc_locals + + # Note that by using order_id to access the order_dict we check that + # 0 <= order_id < 2**ORDER_TREE_HEIGHT = ORDER_ID_UPPER_BOUND. + let (range_check_ptr, order_id) = extract_order_id( + range_check_ptr=range_check_ptr, message_hash=message_hash) + + local fulfilled_amount + %{ + ids.fulfilled_amount = __dict_manager.get_dict(ids.orders_dict)[ids.order_id] + # Prepare error_code in case of error. This won't affect the cairo logic. + if ids.update_amount > ids.remaining_capacity: + error_code = ids.PerpetualErrorCode.INVALID_FULFILLMENT_INFO + else: + # If there's an error in this case, then it's because update_amount is negative. + error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_AMOUNT + %} + let remaining_capacity = full_amount - fulfilled_amount + + # Check that 0 <= update_amount <= full_amount - fulfilled_amount. + # Note that we may have remaining_capacity < 0 in the case of a collision in the order_id. 
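# Illustrative sketch, not part of the diff: the replay-protection bookkeeping that
# update_order_fulfillment implements over the orders dict, restated in Python.
# AMOUNT_UPPER_BOUND below is a placeholder; the real bound lives in constants.cairo.

AMOUNT_UPPER_BOUND = 2**64  # placeholder

def update_order_fulfillment(orders, order_id, update_amount, full_amount):
    fulfilled = orders.get(order_id, 0)
    remaining_capacity = full_amount - fulfilled
    assert 0 <= update_amount <= remaining_capacity, 'INVALID_FULFILLMENT_INFO / OUT_OF_RANGE_AMOUNT'
    assert full_amount < AMOUNT_UPPER_BOUND, 'OUT_OF_RANGE_AMOUNT'
    orders[order_id] = fulfilled + update_amount

orders = {}
update_order_fulfillment(orders, order_id=7, update_amount=10, full_amount=25)
update_order_fulfillment(orders, order_id=7, update_amount=15, full_amount=25)
# A further fulfillment of order 7 would now fail: its full amount has been consumed.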
+ assert_nn_le{range_check_ptr=range_check_ptr}(update_amount, remaining_capacity) + %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_AMOUNT %} + # Check that full_amount < AMOUNT_UPPER_BOUND. + assert_le{range_check_ptr=range_check_ptr}(full_amount, AMOUNT_UPPER_BOUND - 1) + %{ del error_code %} + + dict_update{dict_ptr=orders_dict}( + key=order_id, prev_value=fulfilled_amount, new_value=fulfilled_amount + update_amount) + + return (range_check_ptr=range_check_ptr, orders_dict=orders_dict) +end + +# Does the generic book keeping for a user signed order (limit_order, withdrawal, transfer, etc.). +# Checks the signature, the expiration_timestamp and calls update_order_fulfillment. +# The caller is responsible for the order specific logic. I.e., updating the positions dict. +func validate_order_and_update_fulfillment( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, orders_dict : DictAccess*, message_hash, + order : OrderBase*, min_expiration_timestamp, update_amount, full_amount) -> ( + range_check_ptr, ecdsa_ptr : SignatureBuiltin*, orders_dict : DictAccess*): + %{ error_code = ids.PerpetualErrorCode.INVALID_SIGNATURE %} + with ecdsa_ptr: + verify_ecdsa_signature( + message=message_hash, + public_key=order.public_key, + signature_r=order.signature_r, + signature_s=order.signature_s) + end + %{ del error_code %} + assert_in_range{range_check_ptr=range_check_ptr}( + order.expiration_timestamp, min_expiration_timestamp, EXPIRATION_TIMESTAMP_UPPER_BOUND) + assert_nn_le{range_check_ptr=range_check_ptr}(order.nonce, NONCE_UPPER_BOUND - 1) + + let (range_check_ptr, orders_dict : DictAccess*) = update_order_fulfillment( + range_check_ptr=range_check_ptr, + orders_dict=orders_dict, + message_hash=message_hash, + update_amount=update_amount, + full_amount=full_amount) + + return (range_check_ptr=range_check_ptr, ecdsa_ptr=ecdsa_ptr, orders_dict=orders_dict) +end diff --git a/src/services/perpetual/cairo/order/validate_limit_order.cairo b/src/services/perpetual/cairo/order/validate_limit_order.cairo new file mode 100644 index 0000000..fdf9384 --- /dev/null +++ b/src/services/perpetual/cairo/order/validate_limit_order.cairo @@ -0,0 +1,75 @@ +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.order.limit_order import LimitOrder +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.math import assert_le, assert_lt, assert_nn + +# Validates limit order fulfillment. +# Asserts that the actual amounts are fair with respect to the signed limit order. +# +# Checks that the party will not lose on the ratio between buying and selling and will not be +# charged more fee, with respect to the collateral, than was signed on in the order. +# I.e., +# actual_fee / actual_collaterall <= amount_fee / amount_collateral +# actual_sold / actual_bought <= amount_sell / amount_buy (with the relaxation below). +# +# We relax the inequality, to allow rounding of actual_collateral to either side. +# This allows to always produce valid actual_amounts for opposing orders, +# even when the amounts are very small. +# +# Assumption: +# 1 <= actual_synthetic <= order.amount_synthetic < AMOUNT_UPPER_BOUND +# 0 <= actual_collateral < AMOUNT_UPPER_BOUND +# 0 <= actual_fee < AMOUNT_UPPER_BOUND +# 0 <= order.amount_collateral < AMOUNT_UPPER_BOUND +# 0 <= order.amount_fee < AMOUNT_UPPER_BOUND +# AMOUNT_UPPER_BOUND**2 <= rc_bound. 
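# Illustrative sketch, not part of the diff: the fairness conditions described above, written as
# integer cross-multiplications (no division needed), including the one-unit rounding relaxation
# on the collateral side.

def validate_limit_order_fairness(order, actual_collateral, actual_synthetic, actual_fee):
    # Fee check: actual_fee / actual_collateral <= amount_fee / amount_collateral.
    assert actual_fee * order['amount_collateral'] <= order['amount_fee'] * actual_collateral

    if order['is_buying_synthetic']:
        if actual_collateral == 0:
            return  # Receiving synthetic for free is always acceptable to the buyer.
        # Selling collateral, so actual_collateral may be rounded up by one unit.
        assert (actual_collateral - 1) * order['amount_synthetic'] < \
            order['amount_collateral'] * actual_synthetic
    else:
        # Buying collateral, so actual_collateral may be rounded down by one unit.
        assert actual_synthetic * order['amount_collateral'] < \
            order['amount_synthetic'] * (actual_collateral + 1)

# A buy of 1 synthetic unit out of an order for 3 units priced at 1000 collateral in total:
# 334 = ceil(1000 / 3) is accepted thanks to the rounding relaxation.
order = dict(amount_collateral=1000, amount_synthetic=3, amount_fee=1, is_buying_synthetic=True)
validate_limit_order_fairness(order, actual_collateral=334, actual_synthetic=1, actual_fee=0)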
+func validate_limit_order_fairness(
+        range_check_ptr, limit_order : LimitOrder*, actual_collateral, actual_synthetic,
+        actual_fee) -> (range_check_ptr):
+    tempvar amount_collateral = limit_order.amount_collateral
+
+    # The party won't be charged more fee, with respect to the collateral asset, than was signed
+    # on in the limit_order:
+    # actual_fee / actual_collateral <= amount_fee / amount_collateral, thus
+    # actual_fee * amount_collateral <= amount_fee * actual_collateral.
+    %{ error_code = ids.PerpetualErrorCode.INVALID_FULFILLMENT_FEE_RATIO %}
+    assert_le{range_check_ptr=range_check_ptr}(
+        actual_fee * amount_collateral, limit_order.amount_fee * actual_collateral)
+    %{ del error_code %}
+
+    if limit_order.is_buying_synthetic != 0:
+        # Buying synthetic.
+        let actual_sold = actual_collateral
+        let actual_bought = actual_synthetic
+        let amount_sell = amount_collateral
+        let amount_buy = limit_order.amount_synthetic
+
+        if actual_sold == 0:
+            # The buyer is always willing to get a synthetic for free.
+            return (range_check_ptr=range_check_ptr)
+        end
+
+        # Since we sell collateral, we round actual_sold.
+        # (actual_sold - 1) / actual_bought < amount_sell / amount_buy, thus
+        # (actual_sold - 1) * amount_buy < amount_sell * actual_bought.
+        %{ error_code = ids.PerpetualErrorCode.INVALID_FULFILLMENT_ASSETS_RATIO %}
+        assert_lt{range_check_ptr=range_check_ptr}(
+            (actual_sold - 1) * amount_buy, amount_sell * actual_bought)
+        %{ del error_code %}
+        return (range_check_ptr=range_check_ptr)
+    end
+
+    # Selling synthetic.
+    let actual_sold = actual_synthetic
+    let actual_bought = actual_collateral
+    let amount_sell = limit_order.amount_synthetic
+    let amount_buy = amount_collateral
+    # Since we buy collateral, we round actual_bought.
+    # actual_sold / (actual_bought + 1) < amount_sell / amount_buy, thus
+    # actual_sold * amount_buy < amount_sell * (actual_bought + 1).
+    %{ error_code = ids.PerpetualErrorCode.INVALID_FULFILLMENT_ASSETS_RATIO %}
+    assert_lt{range_check_ptr=range_check_ptr}(
+        actual_sold * amount_buy, amount_sell * (actual_bought + 1))
+    %{ del error_code %}
+    return (range_check_ptr=range_check_ptr)
+end
diff --git a/src/services/perpetual/cairo/output/data_availability.cairo b/src/services/perpetual/cairo/output/data_availability.cairo
new file mode 100644
index 0000000..dc30bdd
--- /dev/null
+++ b/src/services/perpetual/cairo/output/data_availability.cairo
@@ -0,0 +1,64 @@
+from services.perpetual.cairo.definitions.objects import (
+    FundingIndicesInfo, funding_indices_info_serialize)
+from services.perpetual.cairo.output.program_output import PerpetualOutputs
+from services.perpetual.cairo.position.serialize_change import serialize_position_change
+from services.perpetual.cairo.state.state import SquashedCarriedState
+from starkware.cairo.common.dict_access import DictAccess
+from starkware.cairo.common.registers import get_label_location
+from starkware.cairo.common.serialize import serialize_array
+
+# Serializes a single FundingIndicesInfo entry.
+func funding_indices_info_ptr_serialize{output_ptr : felt*}(
+        funding_indices_ptr : FundingIndicesInfo**):
+    return funding_indices_info_serialize(funding_indices=[funding_indices_ptr])
+end
+
+# Outputs the position changes.
+func output_changed_positions( + range_check_ptr, output_ptr : felt*, squashed_dict : DictAccess*, n_entries) -> ( + range_check_ptr, output_ptr : felt*): + if n_entries == 0: + return (range_check_ptr=range_check_ptr, output_ptr=output_ptr) + end + + let (range_check_ptr, output_ptr) = serialize_position_change( + range_check_ptr=range_check_ptr, output_ptr=output_ptr, dict_access=squashed_dict) + + return output_changed_positions( + range_check_ptr=range_check_ptr, + output_ptr=output_ptr, + squashed_dict=squashed_dict + DictAccess.SIZE, + n_entries=n_entries - 1) +end + +# Outputs the data required for data availability. +func output_availability_data( + range_check_ptr, output_ptr : felt*, squashed_state : SquashedCarriedState*, + perpetual_outputs_start : PerpetualOutputs*, perpetual_outputs_end : PerpetualOutputs*) -> ( + range_check_ptr, output_ptr : felt*): + alloc_locals + # Serialize the funding indices table. + let (callback_adddress) = get_label_location(label_value=funding_indices_info_ptr_serialize) + let funding_indices_table_size = ( + perpetual_outputs_end.funding_indices_table_ptr - + perpetual_outputs_start.funding_indices_table_ptr) + + with output_ptr: + serialize_array( + array=cast(perpetual_outputs_start.funding_indices_table_ptr, felt*), + n_elms=funding_indices_table_size, + elm_size=1, + callback=callback_adddress) + end + + # Serialize the position changes. + let dict_len = ( + cast(squashed_state.positions_dict_end, felt) - cast(squashed_state.positions_dict, felt)) + let (range_check_ptr, output_ptr) = output_changed_positions( + range_check_ptr=range_check_ptr, + output_ptr=output_ptr, + squashed_dict=squashed_state.positions_dict, + n_entries=dict_len / DictAccess.SIZE) + + return (range_check_ptr=range_check_ptr, output_ptr=output_ptr) +end diff --git a/src/services/perpetual/cairo/output/forced.cairo b/src/services/perpetual/cairo/output/forced.cairo new file mode 100644 index 0000000..b509c79 --- /dev/null +++ b/src/services/perpetual/cairo/output/forced.cairo @@ -0,0 +1,86 @@ +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.serialize import serialize_word + +namespace ForcedActionType: + const FORCED_WITHDRAWAL = 0 + const FORCED_TRADE = 1 +end + +# The parameters of any forced action that are registered onchain. +# A forced action is a transaction that is registered onchain and once it is registered it must be +# fulfilled in a set amount of time. +struct ForcedAction: + member forced_type : felt + member forced_action : felt* +end + +# A forced action of withdrawal. +struct ForcedWithdrawalAction: + member public_key : felt + member position_id : felt + member amount : felt +end + +func forced_withdrawal_action_new(public_key, position_id, amount) -> ( + forced_withdrawal_action : ForcedWithdrawalAction*): + let (fp_val, pc_val) = get_fp_and_pc() + return ( + forced_withdrawal_action=cast(fp_val - 2 - ForcedWithdrawalAction.SIZE, ForcedWithdrawalAction*)) +end + +func forced_withdrawal_action_serialize{output_ptr : felt*}( + forced_withdrawal_action : ForcedWithdrawalAction*): + serialize_word(forced_withdrawal_action.public_key) + serialize_word(forced_withdrawal_action.position_id) + serialize_word(forced_withdrawal_action.amount) + return () +end + +# A forced action of trade. 
+struct ForcedTradeAction: + member public_key_a : felt + member public_key_b : felt + member position_id_a : felt + member position_id_b : felt + member synthetic_asset_id : felt + member amount_collateral : felt + member amount_synthetic : felt + member is_party_a_buying_synthetic : felt + member nonce : felt +end + +func forced_trade_action_new( + public_key_a, public_key_b, position_id_a, position_id_b, synthetic_asset_id, + amount_collateral, amount_synthetic, is_party_a_buying_synthetic, nonce) -> ( + forced_trade_action : ForcedTradeAction*): + let (fp_val, pc_val) = get_fp_and_pc() + return (forced_trade_action=cast(fp_val - 2 - ForcedTradeAction.SIZE, ForcedTradeAction*)) +end + +func forced_trade_action_serialize{output_ptr : felt*}(forced_trade_action : ForcedTradeAction*): + serialize_word(forced_trade_action.public_key_a) + serialize_word(forced_trade_action.public_key_b) + serialize_word(forced_trade_action.position_id_a) + serialize_word(forced_trade_action.position_id_b) + serialize_word(forced_trade_action.synthetic_asset_id) + serialize_word(forced_trade_action.amount_collateral) + serialize_word(forced_trade_action.amount_synthetic) + serialize_word(forced_trade_action.is_party_a_buying_synthetic) + serialize_word(forced_trade_action.nonce) + return () +end + +func forced_action_serialize{output_ptr : felt*}(forced_action : ForcedAction*): + let forced_type = forced_action.forced_type + serialize_word(forced_type) + if forced_type == ForcedActionType.FORCED_WITHDRAWAL: + return forced_withdrawal_action_serialize( + cast(forced_action.forced_action, ForcedWithdrawalAction*)) + end + if forced_type == ForcedActionType.FORCED_TRADE: + return forced_trade_action_serialize(cast(forced_action.forced_action, ForcedTradeAction*)) + end + + assert 1 = 0 + jmp rel 0 +end diff --git a/src/services/perpetual/cairo/output/program_input.cairo b/src/services/perpetual/cairo/output/program_input.cairo new file mode 100644 index 0000000..2a98a74 --- /dev/null +++ b/src/services/perpetual/cairo/output/program_input.cairo @@ -0,0 +1,15 @@ +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.oracle.oracle_price import AssetOraclePrice +from services.perpetual.cairo.state.state import CarriedState, SharedState +from services.perpetual.cairo.transactions.transaction import Transactions + +struct ProgramInput: + member general_config : GeneralConfig* + member prev_shared_state : SharedState* + member new_shared_state : SharedState* + member minimum_expiration_timestamp : felt + member txs : Transactions* + member n_signed_oracle_prices : felt + member signed_min_oracle_prices : AssetOraclePrice* + member signed_max_oracle_prices : AssetOraclePrice* +end diff --git a/src/services/perpetual/cairo/output/program_output.cairo b/src/services/perpetual/cairo/output/program_output.cairo new file mode 100644 index 0000000..e990e9c --- /dev/null +++ b/src/services/perpetual/cairo/output/program_output.cairo @@ -0,0 +1,136 @@ +from services.perpetual.cairo.definitions.general_config_hash import AssetConfigHashEntry +from services.perpetual.cairo.definitions.objects import ( + FundingIndicesInfo, funding_indices_info_serialize) +from services.perpetual.cairo.output.forced import ForcedAction, forced_action_serialize +from services.perpetual.cairo.state.state import SharedState, shared_state_serialize +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.registers import 
get_fp_and_pc +from starkware.cairo.common.serialize import serialize_array, serialize_word + +# Represents an external modification to the amount of collateral in a position (Deposit or +# withdrawal). +struct Modification: + member public_key : felt + member position_id : felt + # Biased representation. biased_delta is in range [0, 2**65), where 2**64 means 0 change. + # The effective difference is biased_delta - 2**64. + member biased_delta : felt +end + +func modification_serialize{output_ptr : felt*}(modification : Modification*): + serialize_word(modification.public_key) + serialize_word(modification.position_id) + serialize_word(modification.biased_delta) + return () +end + +func asset_config_hash_serialize{output_ptr : felt*}(asset_config_hash : AssetConfigHashEntry*): + serialize_word(asset_config_hash.asset_id) + serialize_word(asset_config_hash.config_hash) + return () +end + +# Represents the entire output of the program. +struct ProgramOutput: + member general_config_hash : felt + member n_asset_configs : felt + member asset_configs : AssetConfigHashEntry* + member prev_shared_state : SharedState* + member new_shared_state : SharedState* + member minimum_expiration_timestamp : felt + + member n_modifications : felt + member modifications : Modification* + member n_forced_actions : felt + member forced_actions : ForcedAction* + member n_conditions : felt + member conditions : felt* +end + +func program_output_new( + general_config_hash, n_asset_configs, asset_configs : AssetConfigHashEntry*, + prev_shared_state : SharedState*, new_shared_state : SharedState*, + minimum_expiration_timestamp, n_modifications, modifications : Modification*, + n_forced_actions, forced_actions : ForcedAction*, n_conditions, conditions : felt*) -> ( + program_output : ProgramOutput*): + let (fp_val, pc_val) = get_fp_and_pc() + return (program_output=cast(fp_val - 2 - ProgramOutput.SIZE, ProgramOutput*)) +end + +# Represents the outputs that were accumulated during the execution of the batch. +struct PerpetualOutputs: + member modifications_ptr : Modification* + member forced_actions_ptr : ForcedAction* + member conditions_ptr : felt* + # A log of all the funding indices. When serializing a position change, The funding + # timestamp is serialized instead of the funding indices and it can be looked up in this log. 
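+    # For example, a position change serialized with funding timestamp T is decoded by locating the
+    # FundingIndicesInfo entry in this table whose funding_timestamp equals T.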
+ member funding_indices_table_ptr : FundingIndicesInfo** +end + +func perpetual_outputs_new( + modifications_ptr : Modification*, forced_actions_ptr : ForcedAction*, + conditions_ptr : felt*, funding_indices_table_ptr : FundingIndicesInfo**) -> ( + outputs : PerpetualOutputs*): + let (fp_val, pc_val) = get_fp_and_pc() + return (outputs=cast(fp_val - 2 - PerpetualOutputs.SIZE, PerpetualOutputs*)) +end + +func perpetual_outputs_empty() -> (outputs : PerpetualOutputs*): + let (modifications_ptr : Modification*) = alloc() + let (forced_actions_ptr : ForcedAction*) = alloc() + let (conditions_ptr : felt*) = alloc() + let (funding_indices_table_ptr : FundingIndicesInfo**) = alloc() + return perpetual_outputs_new( + modifications_ptr=modifications_ptr, + forced_actions_ptr=forced_actions_ptr, + conditions_ptr=conditions_ptr, + funding_indices_table_ptr=funding_indices_table_ptr) +end + +func program_output_serialize{output_ptr : felt*}(program_output : ProgramOutput*): + alloc_locals + + let (_, __pc__) = get_fp_and_pc() + + ret_pc_label: + local __pc__ = __pc__ + + serialize_word(program_output.general_config_hash) + serialize_array( + array=program_output.asset_configs, + n_elms=program_output.n_asset_configs, + elm_size=AssetConfigHashEntry.SIZE, + callback=asset_config_hash_serialize + __pc__ - ret_pc_label) + shared_state_serialize(program_output.prev_shared_state) + shared_state_serialize(program_output.new_shared_state) + + serialize_word(program_output.minimum_expiration_timestamp) + + # Modifications. + serialize_array( + array=program_output.modifications, + n_elms=program_output.n_modifications, + elm_size=Modification.SIZE, + callback=modification_serialize + __pc__ - ret_pc_label) + # Forced actions. + # Save a cell for total size of forced actions. + local forced_actions_size_output_ptr : felt* = output_ptr + let output_ptr = output_ptr + 1 + serialize_array( + array=program_output.forced_actions, + n_elms=program_output.n_forced_actions, + elm_size=ForcedAction.SIZE, + callback=forced_action_serialize + __pc__ - ret_pc_label) + # output_ptr - forced_actions_size_output_ptr is the size of written data including + # forced_actions_size and n_forced_actions. + let data_size = cast(output_ptr, felt) - cast(forced_actions_size_output_ptr, felt) - 2 + serialize_word{output_ptr=forced_actions_size_output_ptr}(data_size) + + # Conditions. 
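+    # Conditions are emitted as a length-prefixed felt array: n_conditions followed by the raw
+    # condition values, copied verbatim with memcpy.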
+ serialize_word(program_output.n_conditions) + local output_ptr : felt* = output_ptr + memcpy(dst=output_ptr, src=program_output.conditions, len=program_output.n_conditions) + let output_ptr = output_ptr + program_output.n_conditions + return () +end diff --git a/src/services/perpetual/cairo/position/add_asset.cairo b/src/services/perpetual/cairo/position/add_asset.cairo new file mode 100644 index 0000000..ce9cd58 --- /dev/null +++ b/src/services/perpetual/cairo/position/add_asset.cairo @@ -0,0 +1,177 @@ +from services.perpetual.cairo.definitions.constants import POSITION_MAX_SUPPORTED_N_ASSETS +from services.perpetual.cairo.definitions.objects import FundingIndex, FundingIndicesInfo +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.position.position import ( + Position, PositionAsset, check_request_public_key, check_valid_balance, + create_maybe_empty_position) +from starkware.cairo.common.find_element import search_sorted, search_sorted_lower +from starkware.cairo.common.math import assert_not_equal +from starkware.cairo.common.memcpy import memcpy + +# Fetches the balance and cached funding index of a position asset if found. +# Otherwise, returns 0 balance and fetches funding index from global_funding_indices. +func get_old_asset( + range_check_ptr, asset_ptr : PositionAsset*, asset_found, + global_funding_indices : FundingIndicesInfo*, asset_id) -> ( + range_check_ptr, balance, funding_index, return_code): + if asset_found != 0: + # Asset found. + return ( + range_check_ptr=range_check_ptr, + balance=asset_ptr.balance, + funding_index=asset_ptr.cached_funding_index, + return_code=PerpetualErrorCode.SUCCESS) + end + + # Previous asset missing => initial balance is zero. + # Find funding index. + let (found_funding_index : FundingIndex*, success) = search_sorted{ + range_check_ptr=range_check_ptr}( + array_ptr=global_funding_indices.funding_indices, + elm_size=FundingIndex.SIZE, + n_elms=global_funding_indices.n_funding_indices, + key=asset_id) + if success == 0: + return ( + range_check_ptr=range_check_ptr, + balance=0, + funding_index=0, + return_code=PerpetualErrorCode.ERROR) + end + + return ( + range_check_ptr=range_check_ptr, + balance=0, + funding_index=found_funding_index.funding_index, + return_code=PerpetualErrorCode.SUCCESS) +end + +# Builds the result position assets array after adding delta to the original assets array at +# asset_id. +func add_asset_inner( + range_check_ptr, n_assets, assets_ptr : PositionAsset*, res_ptr : PositionAsset*, + global_funding_indices : FundingIndicesInfo*, asset_id, delta) -> ( + range_check_ptr, end_ptr : PositionAsset*, return_code): + alloc_locals + # Split original assets array, around asset_id. + let (left_end_ptr : PositionAsset*) = search_sorted_lower{range_check_ptr=range_check_ptr}( + array_ptr=assets_ptr, elm_size=PositionAsset.SIZE, n_elms=n_assets, key=asset_id) + # left_end_ptr is the pointer before current asset. + local left_end_ptr : PositionAsset* = left_end_ptr + let (right_start_ptr : PositionAsset*) = search_sorted_lower{range_check_ptr=range_check_ptr}( + array_ptr=assets_ptr, elm_size=PositionAsset.SIZE, n_elms=n_assets, key=asset_id + 1) + local range_check_ptr = range_check_ptr + # right_start_ptr is the pointer after current asset. + local right_start_ptr : PositionAsset* = right_start_ptr + + # Auxiliary variables. 
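+    # left_size and right_size below are measured in memory cells (felts) rather than in assets,
+    # matching both the pointer arithmetic and the len argument of memcpy; res_left_end is the slot
+    # in the result array where the updated asset (if any) is written.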
+ local assets_end_ptr : PositionAsset* = assets_ptr + n_assets * PositionAsset.SIZE + local left_size = left_end_ptr - assets_ptr + local right_size = assets_end_ptr - right_start_ptr + local res_left_end : PositionAsset* = res_ptr + left_size + + # Compute current balance and funding index. + let (range_check_ptr, balance, local funding_index, return_code) = get_old_asset( + range_check_ptr=range_check_ptr, + asset_ptr=left_end_ptr, + asset_found=right_start_ptr - left_end_ptr, + global_funding_indices=global_funding_indices, + asset_id=asset_id) + + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, end_ptr=res_ptr, return_code=return_code) + end + + # Check new balance validity. + local new_balance = balance + delta + let (local range_check_ptr, return_code) = check_valid_balance( + range_check_ptr=range_check_ptr, balance=new_balance) + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, end_ptr=res_ptr, return_code=return_code) + end + + # Copy left portion. + memcpy(dst=res_ptr, src=assets_ptr, len=left_size) + + # Don't write new asset if new balance is 0. + if new_balance == 0: + # Copy right portion. + memcpy(dst=res_left_end, src=right_start_ptr, len=right_size) + return ( + range_check_ptr=range_check_ptr, + end_ptr=res_left_end + right_size, + return_code=PerpetualErrorCode.SUCCESS) + end + + # Write new asset. + assert res_left_end.asset_id = asset_id + assert res_left_end.balance = new_balance + assert res_left_end.cached_funding_index = funding_index + + # Copy right portion. + let res_right_start = res_left_end + PositionAsset.SIZE + memcpy(dst=res_right_start, src=right_start_ptr, len=right_size) + return ( + range_check_ptr=range_check_ptr, + end_ptr=res_right_start + right_size, + return_code=PerpetualErrorCode.SUCCESS) +end + +# Changes an asset balance of a position by delta. delta may be negative. Handles non existing and +# empty assets correctly. +func position_add_asset( + range_check_ptr, position : Position*, global_funding_indices : FundingIndicesInfo*, + asset_id, delta, public_key) -> (range_check_ptr, position : Position*, return_code): + # Allow invalid asset_id when delta == 0. + if delta == 0: + return ( + range_check_ptr=range_check_ptr, + position=position, + return_code=PerpetualErrorCode.SUCCESS) + end + + local res_assets_ptr : PositionAsset* + %{ ids.res_assets_ptr = segments.add() %} + alloc_locals + + # Verify public_key. + let (return_code) = check_request_public_key( + position_public_key=position.public_key, request_public_key=public_key) + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, position=position, return_code=return_code) + end + + # Call add_asset_inner. + let (local range_check_ptr, end_ptr : PositionAsset*, return_code) = add_asset_inner( + range_check_ptr=range_check_ptr, + n_assets=position.n_assets, + assets_ptr=position.assets_ptr, + res_ptr=res_assets_ptr, + global_funding_indices=global_funding_indices, + asset_id=asset_id, + delta=delta) + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, position=position, return_code=return_code) + end + + tempvar res_n_assets = (end_ptr - res_assets_ptr) / PositionAsset.SIZE + # A single position may not contain more than POSITION_MAX_SUPPORTED_N_ASSETS assets. 
We may + # assert that (res_n_assets != POSITION_MAX_SUPPORTED_N_ASSETS + 1) instead of + # (res_n_assets <= POSITION_MAX_SUPPORTED_N_ASSETS) since each transaction adds at most one + # asset to a position and therefore checking for inequality is equivalent to comparing. + if res_n_assets == POSITION_MAX_SUPPORTED_N_ASSETS + 1: + return ( + range_check_ptr=range_check_ptr, + position=position, + return_code=PerpetualErrorCode.TOO_MANY_SYNTHETIC_ASSETS_IN_POSITION) + end + + let (position : Position*) = create_maybe_empty_position( + public_key=public_key, + collateral_balance=position.collateral_balance, + n_assets=res_n_assets, + assets_ptr=res_assets_ptr, + funding_timestamp=position.funding_timestamp) + return ( + range_check_ptr=range_check_ptr, position=position, return_code=PerpetualErrorCode.SUCCESS) +end diff --git a/src/services/perpetual/cairo/position/check_smaller_holdings.cairo b/src/services/perpetual/cairo/position/check_smaller_holdings.cairo new file mode 100644 index 0000000..4aadd75 --- /dev/null +++ b/src/services/perpetual/cairo/position/check_smaller_holdings.cairo @@ -0,0 +1,92 @@ +from services.perpetual.cairo.definitions.constants import BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.position.position import Position, PositionAsset +from starkware.cairo.common.math_cmp import is_le, is_nn + +# Inner function for check_smaller_in_synthetic_holdings_inner. Checks a single asset and then +# recursively checks the rest. +func check_smaller_in_synthetic_holdings_inner( + range_check_ptr, n_updated_position_assets, updated_position_assets : PositionAsset*, + n_initial_position_assets, initial_position_assets : PositionAsset*) -> ( + range_check_ptr, return_code): + if n_updated_position_assets == 0: + # At this point, we've either passed on all initial assets and updated assets, or there are + # remaining assets in the initial position that have been removed from the updated position, + # which means they are valid because their updated balance is 0. + return (range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.SUCCESS) + end + if n_initial_position_assets == 0: + # There is a new synthetic asset. Therefore the position is not smaller in synthetic + # holdings. + return ( + range_check_ptr=range_check_ptr, + return_code=PerpetualErrorCode.ILLEGAL_POSITION_TRANSITION_ENLARGING_SYNTHETIC_HOLDINGS) + end + + alloc_locals + local updated_balance = updated_position_assets.balance + local initial_balance = initial_position_assets.balance + + if updated_position_assets.asset_id != initial_position_assets.asset_id: + # Because the asset ids are sorted, we can assume that the initial position's asset id + # doesn't exist in the updated position. (If that isn't true then we will eventually have + # n_initial_position_assets == 0). + # This means that the initial position's asset has updated balance 0 and we can skip it. + return check_smaller_in_synthetic_holdings_inner( + range_check_ptr=range_check_ptr, + n_updated_position_assets=n_updated_position_assets, + updated_position_assets=updated_position_assets, + n_initial_position_assets=n_initial_position_assets - 1, + initial_position_assets=initial_position_assets + PositionAsset.SIZE) + end + + # Check that updated_balance and initial_balance have the same sign. 
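+    # For example, updated_balance = -3 and initial_balance = 5 give a product of -15, which fails
+    # is_nn and the transition is rejected, while 3 and 5 give 15, which passes. The bound
+    # assumption stated in check_smaller_in_synthetic_holdings keeps the product small enough for
+    # is_nn to classify correctly.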
+ let (success) = is_nn{range_check_ptr=range_check_ptr}(updated_balance * initial_balance) + if success == 0: + return ( + range_check_ptr=range_check_ptr, + return_code=PerpetualErrorCode.ILLEGAL_POSITION_TRANSITION_ENLARGING_SYNTHETIC_HOLDINGS) + end + + # Check that abs(updated_balance) <= abs(initial_balance) using + # (updated_balance^2) <= (initial_balance^2). + # See the assumption in check_smaller_in_synthetic_holdings. + let (success) = is_le{range_check_ptr=range_check_ptr}( + updated_balance * updated_balance, initial_balance * initial_balance) + if success == 0: + return ( + range_check_ptr=range_check_ptr, + return_code=PerpetualErrorCode.ILLEGAL_POSITION_TRANSITION_ENLARGING_SYNTHETIC_HOLDINGS) + end + + return check_smaller_in_synthetic_holdings_inner( + range_check_ptr=range_check_ptr, + n_updated_position_assets=n_updated_position_assets - 1, + updated_position_assets=updated_position_assets + PositionAsset.SIZE, + n_initial_position_assets=n_initial_position_assets - 1, + initial_position_assets=initial_position_assets + PositionAsset.SIZE) +end + +# Checks that updated_position is as safe as the initial position. +# This means that the balance of each asset did not change sign, and its absolute value +# did not increase. +# Returns 1 if the check passes, 0 otherwise. +# +# Assumption: +# All the asset balances are in the range [BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND). +# The position's assets are sorted by asset id. +# max(BALANCE_LOWER_BOUND**2, (BALANCE_UPPER_BOUND - 1)**2) < range_check_builtin.bound. +func check_smaller_in_synthetic_holdings( + range_check_ptr, updated_position : Position*, initial_position : Position*) -> ( + range_check_ptr, return_code): + %{ + assert max( + ids.BALANCE_LOWER_BOUND**2, (ids.BALANCE_UPPER_BOUND - 1)**2) < range_check_builtin.bound + %} + return check_smaller_in_synthetic_holdings_inner( + range_check_ptr=range_check_ptr, + n_updated_position_assets=updated_position.n_assets, + updated_position_assets=updated_position.assets_ptr, + n_initial_position_assets=initial_position.n_assets, + initial_position_assets=initial_position.assets_ptr) +end diff --git a/src/services/perpetual/cairo/position/funding.cairo b/src/services/perpetual/cairo/position/funding.cairo new file mode 100644 index 0000000..b89ba3e --- /dev/null +++ b/src/services/perpetual/cairo/position/funding.cairo @@ -0,0 +1,118 @@ +from services.perpetual.cairo.definitions.constants import ( + BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND, FXP_32_ONE) +from services.perpetual.cairo.definitions.objects import FundingIndex, FundingIndicesInfo +from services.perpetual.cairo.position.position import Position, PositionAsset, position_new +from starkware.cairo.common.find_element import find_element +from starkware.cairo.common.math import assert_nn_le, signed_div_rem +from starkware.cairo.common.registers import get_fp_and_pc + +# Computes the total_funding for a given position and updates the cached funding indices. +# The funding per asset is computed as: +# (global_funding_index - cached_funding_index) * balance. +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# assets_before - a pointer to PositionAsset array. +# global_funding_indices - a pointer to a FundingIndicesInfo. +# current_collateral_fxp - Current collateral as signed (.32) fixed point. +# assets_after - a pointer to an output array, which will be filled with +# the same assets as assets_before but with an updated cached_funding_index. 
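+# For example, an asset with balance 2, cached funding index 3.0 and global funding index 4.5
+# (both fxp 32.32) contributes (4.5 - 3.0) * 2 = 3.0, which is subtracted from the collateral.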
+# +# Returns: +# range_check_ptr - new range check builtin pointer. +# collateral_fxp - The colleteral after the funding was applied as signed (.32) fixed point. +# +# Assumption: current_collateral_fxp does not overflow, it is a sum of 95 bit values. +# Prover assumption: The assets in assets_before are a subset of the assets in +# global_funding_indices. +func apply_funding_inner( + range_check_ptr, assets_before : PositionAsset*, n_assets, + global_funding_indices : FundingIndicesInfo*, current_collateral_fxp, + assets_after : PositionAsset*) -> (range_check_ptr, collateral_fxp): + jmp body if n_assets != 0 + + # Return. + return (range_check_ptr=range_check_ptr, collateral_fxp=current_collateral_fxp) + + body: + alloc_locals + let current_asset : PositionAsset* = assets_before + + local asset_id = current_asset.asset_id + + # The key must be at offset 0. + static_assert FundingIndex.asset_id == 0 + let (funding_index : FundingIndex*) = find_element{range_check_ptr=range_check_ptr}( + array_ptr=global_funding_indices.funding_indices, + elm_size=FundingIndex.SIZE, + n_elms=global_funding_indices.n_funding_indices, + key=asset_id) + + tempvar global_funding_index = funding_index.funding_index + + # Compute fixed point fxp_delta_funding := delta_funding_index * balance. + tempvar balance = current_asset.balance + tempvar delta_funding_index = global_funding_index - current_asset.cached_funding_index + tempvar fxp_delta_funding = delta_funding_index * balance + + # Copy asset to assets_after with an updated cached_funding_index. + let asset_after : PositionAsset* = assets_after + asset_after.asset_id = asset_id + asset_after.cached_funding_index = global_funding_index + asset_after.balance = balance + + # Call recursively. + return apply_funding_inner( + range_check_ptr=range_check_ptr, + assets_before=assets_before + PositionAsset.SIZE, + n_assets=n_assets - 1, + global_funding_indices=global_funding_indices, + current_collateral_fxp=current_collateral_fxp - fxp_delta_funding, + assets_after=assets_after + PositionAsset.SIZE) +end + +# Change the cached funding indices in the position into the updated funding indices and update the +# collateral balance according to the funding diff. +func position_apply_funding( + range_check_ptr, position : Position*, global_funding_indices : FundingIndicesInfo*) -> ( + range_check_ptr, position : Position*): + local new_assets_ptr : PositionAsset* + alloc_locals + + %{ + ids.new_assets_ptr = new_assets_ptr = segments.add() + segments.finalize( + new_assets_ptr.segment_index, + ids.position.n_assets * ids.PositionAsset.SIZE) + %} + + let (range_check_ptr, collateral_fxp) = apply_funding_inner( + range_check_ptr=range_check_ptr, + assets_before=position.assets_ptr, + n_assets=position.n_assets, + global_funding_indices=global_funding_indices, + current_collateral_fxp=position.collateral_balance * FXP_32_ONE, + assets_after=new_assets_ptr) + + # Convert collateral_fxp from fixed points to an integer and range check that + # BALANCE_LOWER_BOUND <= collateral_balance < BALANCE_UPPER_BOUND. + static_assert BALANCE_LOWER_BOUND == -BALANCE_UPPER_BOUND + # The collateral changes due to funding over all positions always sum up to 0 + # (Assuming no rounding). Therefore the collateral delta is rounded down to make sure funding + # does not make collateral out of thin air. 
+ # For example if we have 3 users a, b and c and the computed funding is as follows: + # a = -0.5, b = -0.5, c = 1, we round the funding down to a = -1, b = -1 and c = 1 and therefore + # we lose 1 collateral in the system from funding + # (If instead we rounded up we would've created 1). + let (new_collateral_balance, _) = signed_div_rem{range_check_ptr=range_check_ptr}( + value=collateral_fxp, div=FXP_32_ONE, bound=BALANCE_UPPER_BOUND) + + let (updated_position) = position_new( + public_key=position.public_key, + collateral_balance=new_collateral_balance, + n_assets=position.n_assets, + assets_ptr=new_assets_ptr, + funding_timestamp=global_funding_indices.funding_timestamp) + + return (range_check_ptr=range_check_ptr, position=updated_position) +end diff --git a/src/services/perpetual/cairo/position/hash.cairo b/src/services/perpetual/cairo/position/hash.cairo new file mode 100644 index 0000000..4366038 --- /dev/null +++ b/src/services/perpetual/cairo/position/hash.cairo @@ -0,0 +1,114 @@ +from services.perpetual.cairo.definitions.constants import ( + BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND, FUNDING_INDEX_LOWER_BOUND, FUNDING_INDEX_UPPER_BOUND, + N_ASSETS_UPPER_BOUND) +from services.perpetual.cairo.position.position import Position, PositionAsset +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.hash import hash2 + +# Inner tail recursive function for position_hash. +# +# Assumptions: +# assets.asset_id < ASSET_ID_UPPER_BOUND (Enforced by the solidity contract). +# FUNDING_INDEX_LOWER_BOUND <= assets.cached_funding_index < FUNDING_INDEX_UPPER_BOUND +# BALANCE_LOWER_BOUND <= assets.balance < BALANCE_UPPER_BOUND. +# ASSET_ID_UPPER_BOUND * (FUNDING_INDEX_UPPER_BOUND - FUNDING_INDEX_LOWER_BOUND) * ( +# BALANCE_UPPER_BOUND - BALANCE_LOWER_BOUND) < PRIME +func position_hash_assets{pedersen_ptr : HashBuiltin*}( + assets : PositionAsset*, n_assets, current_hash) -> (assets_hash): + if n_assets == 0: + return (assets_hash=current_hash) + end + + let asset_packed = assets.asset_id + let asset_packed = asset_packed * ( + FUNDING_INDEX_UPPER_BOUND - FUNDING_INDEX_LOWER_BOUND) + ( + assets.cached_funding_index - FUNDING_INDEX_LOWER_BOUND) + let asset_packed = asset_packed * (BALANCE_UPPER_BOUND - BALANCE_LOWER_BOUND) + + (assets.balance - BALANCE_LOWER_BOUND) + + let (result) = hash2{hash_ptr=pedersen_ptr}(x=current_hash, y=asset_packed) + + # Call recursively. + return position_hash_assets( + assets=assets + PositionAsset.SIZE, n_assets=n_assets - 1, current_hash=result) +end + +# Computes the hash of the position. +# +# Arguments: +# pedersen_ptr - a pedersen builtin pointer. +# position - a pointer to Position. +# +# Returns: +# pedersen_ptr - new pedersen builtin pointer. +# position_hash - the hash of the position +# +# Assumptions: +# The assets are sorted by asset_id. +# The position_hash_assets assumptions hold for all the assets. +func position_hash{pedersen_ptr : HashBuiltin*}(position : Position*) -> (position_hash): + let (assets_hash) = position_hash_assets( + assets=position.assets_ptr, n_assets=position.n_assets, current_hash=0) + + # Hash the assests_hash with the public key. + let (result) = hash2{hash_ptr=pedersen_ptr}(x=assets_hash, y=position.public_key) + + # Hash the above with the biased collateral balance and the number of assets. 
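+    # That is, position_hash = hash2(hash2(assets_hash, public_key),
+    # (collateral_balance - BALANCE_LOWER_BOUND) * N_ASSETS_UPPER_BOUND + n_assets). Packing the
+    # biased collateral balance with n_assets is unambiguous as long as
+    # n_assets < N_ASSETS_UPPER_BOUND.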
+ let (result) = hash2{hash_ptr=pedersen_ptr}( + x=result, + y=(position.collateral_balance - BALANCE_LOWER_BOUND) * N_ASSETS_UPPER_BOUND + + position.n_assets) + + return (position_hash=result) +end + +func hash_position_updates_inner{pedersen_ptr : HashBuiltin*}( + update_ptr : DictAccess*, n_updates, hashed_updates_ptr : DictAccess*) -> (): + if n_updates == 0: + return () + end + + assert hashed_updates_ptr.key = update_ptr.key + + # Previous position hash. + let prev_position = cast(update_ptr.prev_value, Position*) + let (hashed_position) = position_hash(position=prev_position) + assert hashed_updates_ptr.prev_value = hashed_position + + # Touch funding_timestamp. + tempvar funding_timestamp = prev_position.funding_timestamp + + # New position hash. + %{ memory[ap] = 1 if ids.update_ptr.prev_value != ids.update_ptr.new_value else 0 %} + jmp not_equal if [ap] != 0; ap++ + + equal: + # Same as previous. Do not recompute hash. + assert update_ptr.prev_value = update_ptr.new_value + assert hashed_updates_ptr.new_value = hashed_position + return hash_position_updates_inner( + update_ptr=update_ptr + DictAccess.SIZE, + n_updates=n_updates - 1, + hashed_updates_ptr=hashed_updates_ptr + DictAccess.SIZE) + + not_equal: + # Recompute hash. + let (hashed_position) = position_hash(position=cast(update_ptr.new_value, Position*)) + assert hashed_updates_ptr.new_value = hashed_position + return hash_position_updates_inner( + update_ptr=update_ptr + DictAccess.SIZE, + n_updates=n_updates - 1, + hashed_updates_ptr=hashed_updates_ptr + DictAccess.SIZE) +end + +# Converts a dict of positions into a dict of position hashes. +func hash_position_updates{pedersen_ptr : HashBuiltin*}(update_ptr : DictAccess*, n_updates) -> ( + hashed_updates_ptr : DictAccess*): + local hashed_updates_ptr : DictAccess* + %{ ids.hashed_updates_ptr = segments.add() %} + alloc_locals + hash_position_updates_inner( + update_ptr=update_ptr, n_updates=n_updates, hashed_updates_ptr=hashed_updates_ptr) + return (hashed_updates_ptr=hashed_updates_ptr) +end diff --git a/src/services/perpetual/cairo/position/position.cairo b/src/services/perpetual/cairo/position/position.cairo new file mode 100644 index 0000000..33abf2e --- /dev/null +++ b/src/services/perpetual/cairo/position/position.cairo @@ -0,0 +1,128 @@ +from services.perpetual.cairo.definitions.constants import BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from starkware.cairo.common.find_element import search_sorted +from starkware.cairo.common.math_cmp import is_nn_le +from starkware.cairo.common.registers import get_fp_and_pc + +# Represents a specific asset in a user position. +struct PositionAsset: + member asset_id : felt + member balance : felt + # A snapshot of the funding index at the last time that funding was applied (fxp 32.32). + member cached_funding_index : felt +end + +# A user position. +struct Position: + member public_key : felt + member collateral_balance : felt + member n_assets : felt + member assets_ptr : PositionAsset* + # funding_timestamp is an auxiliary field that keeps the funding timestamp of funded positions. + # The invariant is that every position we change must have the correct funding_timestamp. + # Note however that it is not a part of the position hash, and thus, we cannot trust this value + # for positions not created during the current run (e.g. previous position in the state). 
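+    # (The per-asset cached_funding_index values, unlike this field, are part of the position hash
+    # and are therefore the authoritative funding snapshot.)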
+ member funding_timestamp : felt +end + +func position_new( + public_key, collateral_balance, n_assets, assets_ptr : PositionAsset*, + funding_timestamp) -> (position : Position*): + let (fp_val, pc_val) = get_fp_and_pc() + return (position=cast(fp_val - 2 - Position.SIZE, Position*)) +end + +# Creates a position with given arguments. +# If the position is empty (collateral_balance == n_assets == 0) the public_key is ignored +# and an empty position is returned. +# The public_key must be non-zero. +func create_maybe_empty_position( + public_key, collateral_balance, n_assets, assets_ptr : PositionAsset*, + funding_timestamp) -> (position : Position*): + jmp body if public_key != 0 + # If public_key == 0 add an unsatisfiable requirement. + public_key = 1 + + body: + jmp assign_position if collateral_balance != 0 + jmp assign_position if n_assets != 0 + return position_new(0, 0, 0, cast(0, PositionAsset*), 0) + + assign_position: + let (fp_val, _) = get_fp_and_pc() + return (position=cast(fp_val - 2 - Position.SIZE, Position*)) +end + +# Checks that the public key supplied in a request to change the position is valid. +# The public key is valid if it matches the position's public key or if the position is empty +# (public key is zero). +# The supplied key may not be zero. +# Return 0 if the check passed, otherwise returns an error code that describes the failure. +func check_request_public_key(position_public_key, request_public_key) -> (return_code): + if request_public_key == 0: + # Invalid request_public_key. + return (return_code=PerpetualErrorCode.INVALID_PUBLIC_KEY) + end + if position_public_key == 0: + # Initial position is empty. + return (return_code=PerpetualErrorCode.SUCCESS) + end + if position_public_key == request_public_key: + # Matching keys. + return (return_code=PerpetualErrorCode.SUCCESS) + end + # Mismatching keys. + return (return_code=PerpetualErrorCode.INVALID_PUBLIC_KEY) +end + +# Checks that value is in the range [BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND) +func check_valid_balance(range_check_ptr, balance) -> (range_check_ptr, return_code): + let (success) = is_nn_le{range_check_ptr=range_check_ptr}( + balance - BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND - BALANCE_LOWER_BOUND - 1) + if success == 0: + return ( + range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.OUT_OF_RANGE_BALANCE) + end + return (range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.SUCCESS) +end + +# Changes the collateral balance of the position by delta. delta may be negative. +func position_add_collateral(range_check_ptr, position : Position*, delta, public_key) -> ( + range_check_ptr, position : Position*, return_code): + alloc_locals + # Verify public_key. 
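+    # If the initial position is empty (public key 0), any valid request key is accepted and the
+    # resulting position is assigned to it.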
+ let (return_code) = check_request_public_key( + position_public_key=position.public_key, request_public_key=public_key) + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, position=position, return_code=return_code) + end + + let (local final_position : Position*) = create_maybe_empty_position( + public_key=public_key, + collateral_balance=position.collateral_balance + delta, + n_assets=position.n_assets, + assets_ptr=position.assets_ptr, + funding_timestamp=position.funding_timestamp) + + let (range_check_ptr, return_code) = check_valid_balance( + range_check_ptr=range_check_ptr, balance=final_position.collateral_balance) + + return (range_check_ptr=range_check_ptr, position=final_position, return_code=return_code) +end + +# Gets the balance of a specific asset in the position. +func position_get_asset_balance(range_check_ptr, position : Position*, asset_id) -> ( + range_check_ptr, balance): + let (position_asset_ptr : PositionAsset*, success) = search_sorted{ + range_check_ptr=range_check_ptr}( + array_ptr=position.assets_ptr, + elm_size=PositionAsset.SIZE, + n_elms=position.n_assets, + key=asset_id) + if success == 0: + # Asset is not in the position. Therefore the balance of that asset is 0 in said position. + return (range_check_ptr=range_check_ptr, balance=0) + else: + return (range_check_ptr=range_check_ptr, balance=position_asset_ptr.balance) + end +end diff --git a/src/services/perpetual/cairo/position/serialize_change.cairo b/src/services/perpetual/cairo/position/serialize_change.cairo new file mode 100644 index 0000000..c342fe9 --- /dev/null +++ b/src/services/perpetual/cairo/position/serialize_change.cairo @@ -0,0 +1,119 @@ +from services.perpetual.cairo.definitions.constants import ( + ASSET_ID_UPPER_BOUND, BALANCE_LOWER_BOUND, BALANCE_UPPER_BOUND) +from services.perpetual.cairo.position.position import Position, PositionAsset +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math_cmp import is_le, is_nn +from starkware.cairo.common.serialize import serialize_word + +# Serializes a position asset for the on-chain data availability. +# +# Assumptions: +# asset_id < ASSET_ID_UPPER_BOUND. +# BALANCE_LOWER_BOUND <= assets.balance < BALANCE_UPPER_BOUND. +# ASSET_ID_UPPER_BOUND * (BALANCE_UPPER_BOUND - BALANCE_LOWER_BOUND) < PRIME. +func serialize_asset{output_ptr : felt*}(asset_id, balance): + serialize_word( + asset_id * (BALANCE_UPPER_BOUND - BALANCE_LOWER_BOUND) + (balance - BALANCE_LOWER_BOUND)) + return () +end + +# return assets_ptr.asset_id if n_asset != 0 and ASSET_ID_UPPER_BOUND otherwise. +func get_asset_id_or_bound(n_asset, assets_ptr : PositionAsset*) -> (asset_id): + if n_asset != 0: + return (asset_id=assets_ptr.asset_id) + else: + return (asset_id=ASSET_ID_UPPER_BOUND) + end +end + +# Inner function for serialize_position_change. +# Serializes the changes of the position assets. +func serialize_position_change_inner( + range_check_ptr, output_ptr : felt*, n_prev_position_assets, + prev_position_assets : PositionAsset*, n_new_position_assets, + new_position_assets : PositionAsset*) -> (range_check_ptr, output_ptr : felt*): + let (prev_asset_id) = get_asset_id_or_bound(n_prev_position_assets, prev_position_assets) + let (new_asset_id) = get_asset_id_or_bound(n_new_position_assets, new_position_assets) + + if prev_asset_id == new_asset_id: + # Both PositionAsset arrays are empty, we are done. 
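+        # (Equality at ASSET_ID_UPPER_BOUND is only possible once both arrays are exhausted, since
+        # real asset ids are strictly below the bound.)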
+ if prev_asset_id == ASSET_ID_UPPER_BOUND: + return (range_check_ptr=range_check_ptr, output_ptr=output_ptr) + end + + if new_position_assets.balance != prev_position_assets.balance: + with output_ptr: + serialize_asset(asset_id=new_asset_id, balance=new_position_assets.balance) + end + else: + tempvar output_ptr = output_ptr + end + + return serialize_position_change_inner( + range_check_ptr=range_check_ptr, + output_ptr=output_ptr, + n_prev_position_assets=n_prev_position_assets - 1, + prev_position_assets=prev_position_assets + PositionAsset.SIZE, + n_new_position_assets=n_new_position_assets - 1, + new_position_assets=new_position_assets + PositionAsset.SIZE) + end + + let (asset_was_deleted) = is_le{range_check_ptr=range_check_ptr}(prev_asset_id, new_asset_id) + if asset_was_deleted != 0: + with output_ptr: + serialize_asset(asset_id=prev_position_assets.asset_id, balance=0) + end + + return serialize_position_change_inner( + range_check_ptr=range_check_ptr, + output_ptr=output_ptr, + n_prev_position_assets=n_prev_position_assets - 1, + prev_position_assets=prev_position_assets + PositionAsset.SIZE, + n_new_position_assets=n_new_position_assets, + new_position_assets=new_position_assets) + end + + # Asset was added. + with output_ptr: + serialize_asset(asset_id=new_position_assets.asset_id, balance=new_position_assets.balance) + end + + return serialize_position_change_inner( + range_check_ptr=range_check_ptr, + output_ptr=output_ptr, + n_prev_position_assets=n_prev_position_assets, + prev_position_assets=prev_position_assets, + n_new_position_assets=n_new_position_assets - 1, + new_position_assets=new_position_assets + PositionAsset.SIZE) +end + +# Outputs the changes between the positions in dict_access. +func serialize_position_change(range_check_ptr, output_ptr : felt*, dict_access : DictAccess*) -> ( + range_check_ptr, output_ptr : felt*): + alloc_locals + local output_start_ptr : felt* = output_ptr + tempvar prev_position = cast(dict_access.prev_value, Position*) + tempvar new_position = cast(dict_access.new_value, Position*) + + # Leaving space for length. 
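+    # The record's total size (excluding this length cell) is written back into output_start_ptr
+    # at the end of this function, so a reader can skip whole position-change records.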
+ let output_ptr = output_ptr + 1 + with output_ptr: + serialize_word(dict_access.key) + serialize_word(new_position.public_key) + serialize_word(new_position.collateral_balance - BALANCE_LOWER_BOUND) + serialize_word(new_position.funding_timestamp) + end + + let (range_check_ptr, output_ptr) = serialize_position_change_inner( + range_check_ptr=range_check_ptr, + output_ptr=output_ptr, + n_prev_position_assets=prev_position.n_assets, + prev_position_assets=prev_position.assets_ptr, + n_new_position_assets=new_position.n_assets, + new_position_assets=new_position.assets_ptr) + + let size = cast(output_ptr, felt) - cast(output_start_ptr, felt) - 1 + serialize_word{output_ptr=output_start_ptr}(size) + + return (range_check_ptr=range_check_ptr, output_ptr=output_ptr) +end diff --git a/src/services/perpetual/cairo/position/status.cairo b/src/services/perpetual/cairo/position/status.cairo new file mode 100644 index 0000000..f09ac68 --- /dev/null +++ b/src/services/perpetual/cairo/position/status.cairo @@ -0,0 +1,121 @@ +from services.perpetual.cairo.definitions.constants import ( + FXP_32_ONE, TOTAL_RISK_UPPER_BOUND, TOTAL_VALUE_LOWER_BOUND, TOTAL_VALUE_UPPER_BOUND) +from services.perpetual.cairo.definitions.general_config import GeneralConfig, SyntheticAssetInfo +from services.perpetual.cairo.definitions.objects import OraclePrice, OraclePrices +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.position.position import Position, PositionAsset +from starkware.cairo.common.find_element import find_element +from starkware.cairo.common.math import abs_value +from starkware.cairo.common.math_cmp import is_in_range, is_le + +# Inner tail recursive function for position_get_status. +# Computes the risk and value of the synthetic assets. +# total_value_rep is signed (.32) fixed point = sum(price * balance) for asset in assets. +# total_risk_rep is unsigned (.64) fixed point = sum(risk_factor * abs(price * balance)) +# for asset in assets. +func position_get_status_inner( + range_check_ptr, assets : PositionAsset*, n_assets, oracle_prices : OraclePrices*, + general_config : GeneralConfig*, total_value_rep, total_risk_rep) -> ( + range_check_ptr, total_value_rep, total_risk_rep): + jmp body if n_assets != 0 + return ( + range_check_ptr=range_check_ptr, + total_value_rep=total_value_rep, + total_risk_rep=total_risk_rep) + + body: + alloc_locals + let current_asset : PositionAsset* = assets + local asset_id = current_asset.asset_id + + # Compute value. + + # The key for `find_element` must be at offset 0. + static_assert OraclePrice.asset_id == 0 + let (oracle_price_elm : OraclePrice*) = find_element{range_check_ptr=range_check_ptr}( + array_ptr=oracle_prices.data, + elm_size=OraclePrice.SIZE, + n_elms=oracle_prices.len, + key=asset_id) + # Signed (96.32) fixed point. + local value_rep = oracle_price_elm.price * current_asset.balance + + # The key must be at offset 0. + static_assert SyntheticAssetInfo.asset_id == 0 + let (synthetic_info : SyntheticAssetInfo*) = find_element{range_check_ptr=range_check_ptr}( + array_ptr=general_config.synthetic_assets_info, + elm_size=SyntheticAssetInfo.SIZE, + n_elms=general_config.n_synthetic_assets_info, + key=asset_id) + local risk_factor = synthetic_info.risk_factor + + let (abs_value_rep) = abs_value{range_check_ptr=range_check_ptr}(value=value_rep) + + # value_rep is a (96.32) fixed point so risk_rep is a (128.64) fixed point. 
+ tempvar risk_rep = abs_value_rep * risk_factor + + return position_get_status_inner( + range_check_ptr=range_check_ptr, + assets=assets + PositionAsset.SIZE, + n_assets=n_assets - 1, + oracle_prices=oracle_prices, + general_config=general_config, + total_value_rep=total_value_rep + value_rep, + total_risk_rep=total_risk_rep + risk_rep) +end + +# Computes the risk and value of a position. Returns an error code if the computed values are out of +# range. +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# position - a pointer to Position. +# oracle_prices - an array of oracle prices. +# general_config - The general config of the program. +# +# Returns: +# range_check_ptr - new range check builtin pointer. +# total_value_rep is signed (.32) fixed point. +# total_risk_rep is unsigned (.64) fixed point. +func position_get_status( + range_check_ptr, position : Position*, oracle_prices : OraclePrices*, + general_config : GeneralConfig*) -> ( + range_check_ptr, total_value_rep, total_risk_rep, return_code): + alloc_locals + let (range_check_ptr, local total_value_rep, local total_risk_rep) = position_get_status_inner( + range_check_ptr=range_check_ptr, + assets=position.assets_ptr, + n_assets=position.n_assets, + oracle_prices=oracle_prices, + general_config=general_config, + total_value_rep=position.collateral_balance * FXP_32_ONE, + total_risk_rep=0) + + const TOTAL_VALUE_LOWER_BOUND_REP = TOTAL_VALUE_LOWER_BOUND * FXP_32_ONE + const TOTAL_VALUE_UPPER_BOUND_REP = TOTAL_VALUE_UPPER_BOUND * FXP_32_ONE + let (res) = is_in_range{range_check_ptr=range_check_ptr}( + total_value_rep, TOTAL_VALUE_LOWER_BOUND_REP, TOTAL_VALUE_UPPER_BOUND_REP) + if res == 0: + return ( + range_check_ptr=range_check_ptr, + total_value_rep=0, + total_risk_rep=0, + return_code=PerpetualErrorCode.OUT_OF_RANGE_TOTAL_VALUE) + end + + const TR_UPPER_BOUND_REP = TOTAL_RISK_UPPER_BOUND * FXP_32_ONE * FXP_32_ONE + let (res) = is_le{range_check_ptr=range_check_ptr}(total_risk_rep, TR_UPPER_BOUND_REP - 1) + if res == 0: + return ( + range_check_ptr=range_check_ptr, + total_value_rep=0, + total_risk_rep=0, + return_code=PerpetualErrorCode.OUT_OF_RANGE_TOTAL_RISK) + end + + return ( + range_check_ptr=range_check_ptr, + total_value_rep=total_value_rep, + total_risk_rep=total_risk_rep, + return_code=PerpetualErrorCode.SUCCESS) +end diff --git a/src/services/perpetual/cairo/position/update_position.cairo b/src/services/perpetual/cairo/position/update_position.cairo new file mode 100644 index 0000000..7b6ba51 --- /dev/null +++ b/src/services/perpetual/cairo/position/update_position.cairo @@ -0,0 +1,185 @@ +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.objects import ( + FundingIndex, FundingIndicesInfo, OraclePrice, OraclePrices) +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.position.add_asset import position_add_asset +from services.perpetual.cairo.position.funding import position_apply_funding +from services.perpetual.cairo.position.position import Position, position_add_collateral +from services.perpetual.cairo.position.validate_state_transition import check_valid_transition +from starkware.cairo.common.dict import dict_update +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.find_element import search_sorted + +# An asset id representing that no asset id is changed. 
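+# As a field element, -1 equals PRIME - 1, which is far above ASSET_ID_UPPER_BOUND, so it cannot
+# collide with a real asset id.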
+const NO_SYNTHETIC_DELTA_ASSET_ID = -1 + +# Checks whether an asset can be traded. An asset can be traded iff it has a price and a funding +# index or if it is NO_SYNTHETIC_DELTA_ASSET_ID. +func is_asset_id_tradable( + range_check_ptr, synthetic_asset_id, synthetic_delta, + global_funding_indices : FundingIndicesInfo*, oracle_prices : OraclePrices*) -> ( + range_check_ptr, return_code): + if synthetic_asset_id == NO_SYNTHETIC_DELTA_ASSET_ID: + assert synthetic_delta = 0 + return (range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.SUCCESS) + end + let (_, success) = search_sorted{range_check_ptr=range_check_ptr}( + array_ptr=oracle_prices.data, + elm_size=OraclePrice.SIZE, + n_elms=oracle_prices.len, + key=synthetic_asset_id) + if success == 0: + return ( + range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.MISSING_ORACLE_PRICE) + end + let (_, success) = search_sorted{range_check_ptr=range_check_ptr}( + array_ptr=global_funding_indices.funding_indices, + elm_size=FundingIndex.SIZE, + n_elms=global_funding_indices.n_funding_indices, + key=synthetic_asset_id) + if success == 0: + return ( + range_check_ptr=range_check_ptr, + return_code=PerpetualErrorCode.MISSING_GLOBAL_FUNDING_INDEX) + end + return (range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.SUCCESS) +end + +# Updates the position with collateral_delta and synthetic_delta and returns the updated position. +# Checks that the transition is valid. +# If the transition is invalid or a failure occured, returns the funded position and a return code +# reporting the problem. +# If the given public key is 0, skip the public key validation and validate instead that the +# position's public key isn't 0. +# Returns the initial position, the updated position and the initial position after funding was +# applied. +func update_position( + range_check_ptr, position : Position*, request_public_key, collateral_delta, + synthetic_asset_id, synthetic_delta, global_funding_indices : FundingIndicesInfo*, + oracle_prices : OraclePrices*, general_config : GeneralConfig*) -> ( + range_check_ptr, updated_position : Position*, funded_position : Position*, return_code): + alloc_locals + local final_position : Position* + let (range_check_ptr, local funded_position) = position_apply_funding( + range_check_ptr=range_check_ptr, + position=position, + global_funding_indices=global_funding_indices) + + # We need to explicitly check that the asset has a price and a funding index because otherwise, + # if the initial and updated position have a balance of 0 for that asset, it won't be caught. + let (range_check_ptr, return_code) = is_asset_id_tradable( + range_check_ptr=range_check_ptr, + synthetic_asset_id=synthetic_asset_id, + synthetic_delta=synthetic_delta, + global_funding_indices=global_funding_indices, + oracle_prices=oracle_prices) + if return_code != PerpetualErrorCode.SUCCESS: + return ( + range_check_ptr=range_check_ptr, + updated_position=funded_position, + funded_position=funded_position, + return_code=return_code) + end + + local public_key + if request_public_key == 0: + # Skip the public key validation by passing the position's public key. 
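+        # If the position is empty its public key is 0, and check_request_public_key rejects a
+        # zero key, so an empty position cannot be changed without an explicit request key.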
+ public_key = funded_position.public_key + else: + public_key = request_public_key + end + let public_key = public_key + + let (range_check_ptr, updated_position, return_code) = position_add_collateral( + range_check_ptr=range_check_ptr, + position=funded_position, + delta=collateral_delta, + public_key=public_key) + if return_code != PerpetualErrorCode.SUCCESS: + final_position = position + return ( + range_check_ptr=range_check_ptr, + updated_position=funded_position, + funded_position=funded_position, + return_code=return_code) + end + + let (range_check_ptr, updated_position : Position*, return_code) = position_add_asset( + range_check_ptr=range_check_ptr, + position=updated_position, + global_funding_indices=global_funding_indices, + asset_id=synthetic_asset_id, + delta=synthetic_delta, + public_key=public_key) + if return_code != PerpetualErrorCode.SUCCESS: + return ( + range_check_ptr=range_check_ptr, + updated_position=funded_position, + funded_position=funded_position, + return_code=return_code) + end + final_position = updated_position + + let (range_check_ptr, return_code) = check_valid_transition( + range_check_ptr, final_position, funded_position, oracle_prices, general_config) + if return_code != PerpetualErrorCode.SUCCESS: + return ( + range_check_ptr=range_check_ptr, + updated_position=funded_position, + funded_position=funded_position, + return_code=return_code) + end + + return ( + range_check_ptr=range_check_ptr, + updated_position=final_position, + funded_position=funded_position, + return_code=PerpetualErrorCode.SUCCESS) +end + +# Updates the position in 'position_id' in a given dict with collateral_delta and synthetic_delta. +# Checks that initially the position is either empty or belongs to request_public_key. +# Checks that the transition is valid. +# If a failure occured, updates the position in the dict to the funded position without any changes, +# and returns a return code reporting the problem. +# If the given public key is 0, skip the public key validation. +# If synthetic delta is 0, then synthetic_asset_id can be NO_SYNTHETIC_DELTA_ASSET_ID to signal that +# no synthetic asset balance is being changed. +# Returns the updated dict, initial position, the updated position and the initial position after +# funding was applied. +func update_position_in_dict( + range_check_ptr, positions_dict : DictAccess*, position_id, request_public_key, + collateral_delta, synthetic_asset_id, synthetic_delta, + global_funding_indices : FundingIndicesInfo*, oracle_prices : OraclePrices*, + general_config : GeneralConfig*) -> ( + range_check_ptr, positions_dict : DictAccess*, funded_position : Position*, + updated_position : Position*, return_code): + local initial_position : Position* + alloc_locals + + %{ ids.initial_position = __dict_manager.get_dict(ids.positions_dict)[ids.position_id] %} + + let (range_check_ptr, updated_position, funded_position, return_code) = update_position( + range_check_ptr=range_check_ptr, + position=initial_position, + request_public_key=request_public_key, + collateral_delta=collateral_delta, + synthetic_asset_id=synthetic_asset_id, + synthetic_delta=synthetic_delta, + global_funding_indices=global_funding_indices, + oracle_prices=oracle_prices, + general_config=general_config) + + # Even if update failed, we need to write the update. 
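+    # On failure, update_position returns the funded position as updated_position, so the funding
+    # application is still recorded in the dict.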
+ dict_update{dict_ptr=positions_dict}( + key=position_id, + prev_value=cast(initial_position, felt), + new_value=cast(updated_position, felt)) + + return ( + range_check_ptr=range_check_ptr, + positions_dict=positions_dict, + funded_position=funded_position, + updated_position=updated_position, + return_code=return_code) +end diff --git a/src/services/perpetual/cairo/position/validate_state_transition.cairo b/src/services/perpetual/cairo/position/validate_state_transition.cairo new file mode 100644 index 0000000..875c113 --- /dev/null +++ b/src/services/perpetual/cairo/position/validate_state_transition.cairo @@ -0,0 +1,83 @@ +from services.perpetual.cairo.definitions.constants import FXP_32_ONE +from services.perpetual.cairo.definitions.general_config import GeneralConfig, SyntheticAssetInfo +from services.perpetual.cairo.definitions.objects import OraclePrices +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.position.check_smaller_holdings import ( + check_smaller_in_synthetic_holdings) +from services.perpetual.cairo.position.position import Position +from services.perpetual.cairo.position.status import position_get_status +from starkware.cairo.common.math_cmp import is_le, is_le_felt + +# Checks if a position update was legal. +# A position update is legal if +# 1. The result position is well leveraged, or +# 2. a. The result position is `smaller` than the original position, and +# b. The ratio between the total_value and the total_risk in the result position is not +# smaller than the same ratio in the original position, and +# c. If the total risk of the original position is 0, the total value of the result +# position is not smaller than the total value of the original position. +func check_valid_transition( + range_check_ptr, updated_position : Position*, initial_position : Position*, + oracle_prices : OraclePrices*, general_config : GeneralConfig*) -> ( + range_check_ptr, return_code): + alloc_locals + let (range_check_ptr, local updated_tv, local updated_tr, return_code) = position_get_status( + range_check_ptr=range_check_ptr, + position=updated_position, + oracle_prices=oracle_prices, + general_config=general_config) + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, return_code=return_code) + end + + let (is_well_leveraged) = is_le{range_check_ptr=range_check_ptr}( + updated_tr, updated_tv * FXP_32_ONE) + if is_well_leveraged != 0: + return (range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.SUCCESS) + end + + let (range_check_ptr, local initial_tv, local initial_tr, return_code) = position_get_status( + range_check_ptr=range_check_ptr, + position=initial_position, + oracle_prices=oracle_prices, + general_config=general_config) + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, return_code=return_code) + end + + let (range_check_ptr, return_code) = check_smaller_in_synthetic_holdings( + range_check_ptr=range_check_ptr, + updated_position=updated_position, + initial_position=initial_position) + if return_code != PerpetualErrorCode.SUCCESS: + return (range_check_ptr=range_check_ptr, return_code=return_code) + end + + # total_value / total_risk must not decrease. + # tv0 / tr0 <= tv1 / tr1 iff tv0 * tr1 <= tv1 * tr0. + # tv is 96 bit. + # tr is 128 bit. + # tv*tr fits in 224 bits. + # Since tv can be negative, adding 2**224 to each side. 
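+    # is_le_felt compares its operands as integers in [0, PRIME); since |tv * tr| < 2**224, both
+    # shifted operands lie in [0, 2**225), so adding 2**224 preserves the ordering.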
+ let (success) = is_le_felt{range_check_ptr=range_check_ptr}( + %[2**224%] + initial_tv * updated_tr, %[2**224%] + updated_tv * initial_tr) + + if success == 0: + let return_code = ( + PerpetualErrorCode.ILLEGAL_POSITION_TRANSITION_REDUCING_TOTAL_VALUE_RISK_RATIO) + return (range_check_ptr=range_check_ptr, return_code=return_code) + end + if initial_tr == 0: + # Edge case: When the total risk is 0 the TV/TR ratio is undefined and we need to check that + # initial_tv <= updated_tv. Note that because we passed + # 'check_smaller_in_synthetic_holdings' and initial_tr == 0 we must have updated_tr == 0. + let (success) = is_le{range_check_ptr=range_check_ptr}(initial_tv, updated_tv) + if success == 0: + return ( + range_check_ptr=range_check_ptr, + return_code=PerpetualErrorCode.ILLEGAL_POSITION_TRANSITION_NO_RISK_REDUCED_VALUE) + end + end + + return (range_check_ptr=range_check_ptr, return_code=PerpetualErrorCode.SUCCESS) +end diff --git a/src/services/perpetual/cairo/program_hash.json b/src/services/perpetual/cairo/program_hash.json new file mode 100644 index 0000000..9a3c878 --- /dev/null +++ b/src/services/perpetual/cairo/program_hash.json @@ -0,0 +1 @@ +{"program_hash": 3034699314518632418633798167565169540135516579166751375550467178691995202911} diff --git a/src/services/perpetual/cairo/program_hash_test.py b/src/services/perpetual/cairo/program_hash_test.py new file mode 100644 index 0000000..bcb4536 --- /dev/null +++ b/src/services/perpetual/cairo/program_hash_test.py @@ -0,0 +1,37 @@ +import json +import os + +from starkware.cairo.bootloader.hash_program import compute_program_hash_chain +from starkware.cairo.lang.compiler.program import Program +from starkware.python.utils import get_source_dir_path + +PROGRAM_PATH = os.path.join(os.path.dirname(__file__), 'perpetual_cairo_compiled.json') +HASH_PATH = get_source_dir_path('src/services/perpetual/cairo/program_hash.json') + + +def run_generate_hash_test(fix: bool): + compiled_program = Program.Schema().load(json.load(open(PROGRAM_PATH))) + program_hash = compute_program_hash_chain(compiled_program) + + if fix: + json.dump(obj={'program_hash': program_hash}, fp=open(HASH_PATH, 'w')) + else: + expected_hash = json.load(open(HASH_PATH))['program_hash'] + assert expected_hash == program_hash, \ + 'Wrong program hash in program_hash.json. ' + \ + 'Please run generate_perpetual_cairo_program_hash.' 
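+# Note: running this module directly with --fix (see the __main__ block below) recomputes the hash
+# and rewrites program_hash.json.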
+ + +def test_perpetual_program_hash(): + run_generate_hash_test(fix=False) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser( + description='Create or test the perpetual program hash.') + parser.add_argument( + '--fix', action='store_true', help='Fix the value of the program hash.') + + args = parser.parse_args() + run_generate_hash_test(fix=args.fix) diff --git a/src/services/perpetual/cairo/state/state.cairo b/src/services/perpetual/cairo/state/state.cairo new file mode 100644 index 0000000..0a8f30a --- /dev/null +++ b/src/services/perpetual/cairo/state/state.cairo @@ -0,0 +1,200 @@ +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.objects import ( + FundingIndicesInfo, OraclePrice, OraclePrices, funding_indices_info_serialize, + oracle_price_serialize) +from services.perpetual.cairo.position.hash import hash_position_updates +from services.perpetual.cairo.position.position import Position +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.dict import dict_new +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.merkle_multi_update import merkle_multi_update +from starkware.cairo.common.registers import get_fp_and_pc, get_label_location +from starkware.cairo.common.serialize import serialize_array, serialize_word +from starkware.cairo.common.squash_dict import squash_dict + +# State carried through batch execution. Keeps the current pointer of all dicts. +struct CarriedState: + member positions_dict : DictAccess* + member orders_dict : DictAccess* + member global_funding_indices : FundingIndicesInfo* + member oracle_prices : OraclePrices* + member system_time : felt +end + +func carried_state_new( + positions_dict : DictAccess*, orders_dict : DictAccess*, + global_funding_indices : FundingIndicesInfo*, oracle_prices : OraclePrices*, + system_time) -> (carried_state : CarriedState*): + let (fp_val, pc_val) = get_fp_and_pc() + return (carried_state=cast(fp_val - 2 - CarriedState.SIZE, CarriedState*)) +end + +# Carried state that keeps the squashed dicts. +struct SquashedCarriedState: + member positions_dict : DictAccess* + member positions_dict_end : DictAccess* + member orders_dict : DictAccess* + member orders_dict_end : DictAccess* + member global_funding_indices : FundingIndicesInfo* + member oracle_prices : OraclePrices* + member system_time : felt +end + +func squashed_carried_state_new( + positions_dict : DictAccess*, positions_dict_end : DictAccess*, orders_dict : DictAccess*, + orders_dict_end : DictAccess*, global_funding_indices : FundingIndicesInfo*, + oracle_prices : OraclePrices*, system_time) -> (carried_state : SquashedCarriedState*): + let (fp_val, pc_val) = get_fp_and_pc() + return (carried_state=cast(fp_val - 2 - SquashedCarriedState.SIZE, SquashedCarriedState*)) +end + +func carried_state_squash{range_check_ptr}( + initial_carried_state : CarriedState*, carried_state : CarriedState*) -> ( + squashed_carried_state : SquashedCarriedState*): + alloc_locals + # Squash positions dict. 
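+    # squash_dict collapses the full access log into a single DictAccess per touched position,
+    # holding its first (prev) and last (new) value; the Merkle update below consumes this
+    # squashed form.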
+ local squashed_positions_dict : DictAccess* + local squashed_positions_dict_end : DictAccess* + %{ ids.squashed_positions_dict = segments.add() %} + let (squashed_positions_dict_end_) = squash_dict( + dict_accesses=initial_carried_state.positions_dict, + dict_accesses_end=carried_state.positions_dict, + squashed_dict=squashed_positions_dict) + squashed_positions_dict_end = squashed_positions_dict_end_ + # Squash orders dict. + local squashed_orders_dict : DictAccess* + local squashed_orders_dict_end : DictAccess* + %{ ids.squashed_orders_dict = segments.add() %} + let (squashed_orders_dict_end_) = squash_dict( + dict_accesses=initial_carried_state.orders_dict, + dict_accesses_end=carried_state.orders_dict, + squashed_dict=squashed_orders_dict) + squashed_orders_dict_end = squashed_orders_dict_end_ + # Return SquashedCarriedState. + let (squashed_carried_state) = squashed_carried_state_new( + positions_dict=squashed_positions_dict, + positions_dict_end=squashed_positions_dict_end, + orders_dict=squashed_orders_dict, + orders_dict_end=squashed_orders_dict_end, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + return (squashed_carried_state=squashed_carried_state) +end + +# State stored on the blockchain. +struct SharedState: + member positions_root : felt + member positions_tree_height : felt + member orders_root : felt + member orders_tree_height : felt + member global_funding_indices : FundingIndicesInfo* + member oracle_prices : OraclePrices* + member system_time : felt +end + +func shared_state_new( + positions_root, positions_tree_height, orders_root, orders_tree_height, + global_funding_indices : FundingIndicesInfo*, oracle_prices : OraclePrices*, + system_time) -> (carried_state : SharedState*): + let (fp_val, pc_val) = get_fp_and_pc() + return (carried_state=cast(fp_val - 2 - SharedState.SIZE, SharedState*)) +end + +# Applies the updates from the squashed carried state on the initial shared state. +# Arguments: +# hash_ptr - Pointer to the hash builtin. +# shared_state - The initial shared state +# squashed_carried_state - The squashed carried state representing the updated state. +# general_config - The general config (It doesn't change throughout the program so it's both initial +# and updated). +# +# Returns: +# hash_ptr - Pointer to the hash builtin. +# shared_state - The shared state that corresponds to the updated state. +func shared_state_apply_state_updates( + hash_ptr : HashBuiltin*, shared_state : SharedState*, + squashed_carried_state : SquashedCarriedState*, general_config : GeneralConfig*) -> ( + hash_ptr : HashBuiltin*, shared_state : SharedState*): + alloc_locals + + # Hash position updates. + local n_position_updates = (squashed_carried_state.positions_dict_end - squashed_carried_state.positions_dict) / DictAccess.SIZE + let (hashed_position_updates_ptr) = hash_position_updates{pedersen_ptr=hash_ptr}( + update_ptr=squashed_carried_state.positions_dict, n_updates=n_position_updates) + + # Merkle update positions dict. + local new_positions_root + %{ ids.new_positions_root = new_positions_root %} + with hash_ptr: + merkle_multi_update( + update_ptr=hashed_position_updates_ptr, + n_updates=n_position_updates, + height=general_config.positions_tree_height, + prev_root=shared_state.positions_root, + new_root=new_positions_root) + # Merkle update orders dict. 
+ local new_orders_root + %{ ids.new_orders_root = new_orders_root %} + merkle_multi_update( + update_ptr=squashed_carried_state.orders_dict, + n_updates=(squashed_carried_state.orders_dict_end - squashed_carried_state.orders_dict) / DictAccess.SIZE, + height=general_config.orders_tree_height, + prev_root=shared_state.orders_root, + new_root=new_orders_root) + end + + # Return SharedState. + let (shared_state) = shared_state_new( + positions_root=new_positions_root, + positions_tree_height=general_config.positions_tree_height, + orders_root=new_orders_root, + orders_tree_height=general_config.orders_tree_height, + global_funding_indices=squashed_carried_state.global_funding_indices, + oracle_prices=squashed_carried_state.oracle_prices, + system_time=squashed_carried_state.system_time) + return (hash_ptr=hash_ptr, shared_state=shared_state) +end + +func shared_state_serialize{output_ptr : felt*}(shared_state : SharedState*): + alloc_locals + local output_start_ptr : felt* = output_ptr + # Storing an empty slot for the size of the structure which will be filled later in the code. + # A single slot due to the implementation of serialize_word which increments the ptr by one. + let output_ptr = output_ptr + 1 + serialize_word(shared_state.positions_root) + serialize_word(shared_state.positions_tree_height) + serialize_word(shared_state.orders_root) + serialize_word(shared_state.orders_tree_height) + funding_indices_info_serialize(shared_state.global_funding_indices) + let (callback_adddress) = get_label_location(label_value=oracle_price_serialize) + serialize_array( + array=shared_state.oracle_prices.data, + n_elms=shared_state.oracle_prices.len, + elm_size=OraclePrice.SIZE, + callback=callback_adddress) + serialize_word(shared_state.system_time) + let size = cast(output_ptr, felt) - cast(output_start_ptr, felt) - 1 + serialize_word{output_ptr=output_start_ptr}(size) + return () +end + +# Converts a shared state into a carried state. +# Arguments: +# shared_state - The current shared state. +# +# Hint Arguments: +# positions_dict - A dict mapping between a position id and its position. +# orders_dict - A dict mapping between an order id and its order's state. 
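shared_state_serialize above reserves a single output slot, serializes the fields, and only then writes the total size into the reserved slot. A small Python sketch of that length-prefix pattern (illustration only, not the Cairo memory model):

```python
def serialize_with_size_prefix(fields):
    output = [None]              # Reserve one slot for the size.
    for value in fields:
        output.append(value)     # The serialize_word equivalent.
    output[0] = len(output) - 1  # Number of words written after the prefix.
    return output

assert serialize_with_size_prefix([10, 20, 30]) == [3, 10, 20, 30]
```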
+func shared_state_to_carried_state(shared_state : SharedState*) -> (carried_state : CarriedState*): + %{ initial_dict = positions_dict %} + let (positions_dict : DictAccess*) = dict_new() + %{ initial_dict = orders_dict %} + let (orders_dict : DictAccess*) = dict_new() + return carried_state_new( + positions_dict=positions_dict, + orders_dict=orders_dict, + global_funding_indices=shared_state.global_funding_indices, + oracle_prices=shared_state.oracle_prices, + system_time=shared_state.system_time) +end diff --git a/src/services/perpetual/cairo/transactions/batch_config.cairo b/src/services/perpetual/cairo/transactions/batch_config.cairo new file mode 100644 index 0000000..2432716 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/batch_config.cairo @@ -0,0 +1,19 @@ +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.objects import OraclePrice +from starkware.cairo.common.registers import get_fp_and_pc + +struct BatchConfig: + member general_config : GeneralConfig* + member signed_min_oracle_prices : OraclePrice* + member signed_max_oracle_prices : OraclePrice* + member n_oracle_prices : felt + member min_expiration_timestamp : felt +end + +func batch_config_new( + general_config : GeneralConfig*, signed_min_oracle_prices : OraclePrice*, + signed_max_oracle_prices : OraclePrice*, n_oracle_prices, min_expiration_timestamp) -> ( + batch_config : BatchConfig*): + let (fp_val, pc_val) = get_fp_and_pc() + return (batch_config=cast(fp_val - 2 - BatchConfig.SIZE, BatchConfig*)) +end diff --git a/src/services/perpetual/cairo/transactions/conditional_transfer.cairo b/src/services/perpetual/cairo/transactions/conditional_transfer.cairo new file mode 100644 index 0000000..ea36327 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/conditional_transfer.cairo @@ -0,0 +1,110 @@ +from services.exchange.cairo.order import OrderBase +from services.perpetual.cairo.definitions.constants import ( + AMOUNT_UPPER_BOUND, EXPIRATION_TIMESTAMP_UPPER_BOUND, NONCE_UPPER_BOUND, + POSITION_ID_UPPER_BOUND) +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.perpetual_error_code import ( + PerpetualErrorCode, assert_success) +from services.perpetual.cairo.order.order import validate_order_and_update_fulfillment +from services.perpetual.cairo.output.program_output import PerpetualOutputs, perpetual_outputs_new +from services.perpetual.cairo.position.update_position import ( + NO_SYNTHETIC_DELTA_ASSET_ID, update_position_in_dict) +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from services.perpetual.cairo.transactions.transfer import Transfer, transfer_hash +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math import assert_nn_le, assert_not_equal + +struct ConditionalTransfer: + member transfer : Transfer* + member condition : felt +end + +func execute_conditional_transfer( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : ConditionalTransfer*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + %{ 
error_code = ids.PerpetualErrorCode.SAME_POSITION_ID %}
+    assert_not_equal(tx.transfer.sender_position_id, tx.transfer.receiver_position_id)
+    %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_AMOUNT %}
+    assert_nn_le{range_check_ptr=range_check_ptr}(tx.transfer.amount, AMOUNT_UPPER_BOUND - 1)
+    %{ del error_code %}
+    local range_check_ptr = range_check_ptr
+    # expiration_timestamp and nonce will be validated in validate_order_and_update_fulfillment.
+    # Asset id is in range because we check that it's equal to the collateral asset id.
+    # The sender's and receiver's position ids will be validated by update_position_in_dict.
+
+    local general_config : GeneralConfig* = batch_config.general_config
+    # Validate that asset is collateral.
+    %{ error_code = ids.PerpetualErrorCode.INVALID_COLLATERAL_ASSET_ID %}
+    assert tx.transfer.asset_id = general_config.collateral_asset_info.asset_id
+    %{ del error_code %}
+
+    let (local pedersen_ptr, message_hash) = transfer_hash(
+        pedersen_ptr=pedersen_ptr, transfer=tx.transfer, condition=tx.condition)
+
+    let (range_check_ptr, local ecdsa_ptr,
+        local orders_dict) = validate_order_and_update_fulfillment(
+        range_check_ptr=range_check_ptr,
+        ecdsa_ptr=ecdsa_ptr,
+        orders_dict=carried_state.orders_dict,
+        message_hash=message_hash,
+        order=tx.transfer.base,
+        min_expiration_timestamp=batch_config.min_expiration_timestamp,
+        update_amount=tx.transfer.amount,
+        full_amount=tx.transfer.amount)
+
+    # Update the sender position.
+    let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict(
+        range_check_ptr=range_check_ptr,
+        positions_dict=carried_state.positions_dict,
+        position_id=tx.transfer.sender_position_id,
+        request_public_key=tx.transfer.base.public_key,
+        collateral_delta=-tx.transfer.amount,
+        synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID,
+        synthetic_delta=0,
+        global_funding_indices=carried_state.global_funding_indices,
+        oracle_prices=carried_state.oracle_prices,
+        general_config=general_config)
+    assert_success(return_code)
+
+    # Update the receiver position.
+    let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict(
+        range_check_ptr=range_check_ptr,
+        positions_dict=positions_dict,
+        position_id=tx.transfer.receiver_position_id,
+        request_public_key=tx.transfer.receiver_public_key,
+        collateral_delta=tx.transfer.amount,
+        synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID,
+        synthetic_delta=0,
+        global_funding_indices=carried_state.global_funding_indices,
+        oracle_prices=carried_state.oracle_prices,
+        general_config=general_config)
+    assert_success(return_code)
+
+    let (carried_state) = carried_state_new(
+        positions_dict=positions_dict,
+        orders_dict=orders_dict,
+        global_funding_indices=carried_state.global_funding_indices,
+        oracle_prices=carried_state.oracle_prices,
+        system_time=carried_state.system_time)
+
+    # Output the condition.
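The two update_position_in_dict calls above debit the sender and credit the receiver by the same amount, so a conditional transfer conserves total collateral. A toy Python sketch of that bookkeeping (illustration only; the real balances live inside Position objects):

```python
def apply_transfer(balances, sender, receiver, amount):
    assert amount >= 0
    updated = dict(balances)
    updated[sender] = updated.get(sender, 0) - amount
    updated[receiver] = updated.get(receiver, 0) + amount
    return updated

balances = {'sender': 1000, 'receiver': 50}
after = apply_transfer(balances, 'sender', 'receiver', 200)
assert after == {'sender': 800, 'receiver': 250}
assert sum(after.values()) == sum(balances.values())
```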
+ assert [outputs.conditions_ptr] = tx.condition + let (outputs : PerpetualOutputs*) = perpetual_outputs_new( + modifications_ptr=outputs.modifications_ptr, + forced_actions_ptr=outputs.forced_actions_ptr, + conditions_ptr=outputs.conditions_ptr + 1, + funding_indices_table_ptr=outputs.funding_indices_table_ptr) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/deleverage.cairo b/src/services/perpetual/cairo/transactions/deleverage.cairo new file mode 100644 index 0000000..88a4c52 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/deleverage.cairo @@ -0,0 +1,150 @@ +from services.perpetual.cairo.definitions.constants import AMOUNT_UPPER_BOUND, FXP_32_ONE +from services.perpetual.cairo.definitions.objects import FundingIndicesInfo +from services.perpetual.cairo.definitions.perpetual_error_code import ( + PerpetualErrorCode, assert_success) +from services.perpetual.cairo.output.program_output import PerpetualOutputs +from services.perpetual.cairo.position.position import ( + Position, position_add_collateral, position_get_asset_balance) +from services.perpetual.cairo.position.status import position_get_status +from services.perpetual.cairo.position.update_position import update_position_in_dict +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict import dict_update +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math import assert_le_250_bit, assert_lt, assert_nn_le, assert_not_equal + +struct Deleverage: + member deleveragable_position_id : felt + member deleverager_position_id : felt + member synthetic_asset_id : felt + member amount_synthetic : felt + member amount_collateral : felt + member deleverager_is_buying_synthetic : felt +end + +func execute_deleverage( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : Deleverage*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + + assert_nn_le{range_check_ptr=range_check_ptr}(tx.amount_synthetic, AMOUNT_UPPER_BOUND) + assert_nn_le{range_check_ptr=range_check_ptr}(tx.amount_collateral, AMOUNT_UPPER_BOUND) + + %{ error_code = ids.PerpetualErrorCode.SAME_POSITION_ID %} + # Assert that the deleverager position and the deleveragable position are distinct. 
+ assert_not_equal(tx.deleverager_position_id, tx.deleveragable_position_id) + %{ del error_code %} + + local global_funding_indices : FundingIndicesInfo* = carried_state.global_funding_indices + + local deleverager_synthetic_delta + local deleveragable_synthetic_delta + local deleverager_collateral_delta + local deleveragable_collateral_delta + + if tx.deleverager_is_buying_synthetic != 0: + assert deleverager_synthetic_delta = tx.amount_synthetic + assert deleveragable_synthetic_delta = -tx.amount_synthetic + assert deleverager_collateral_delta = -tx.amount_collateral + assert deleveragable_collateral_delta = tx.amount_collateral + else: + assert deleverager_synthetic_delta = -tx.amount_synthetic + assert deleveragable_synthetic_delta = tx.amount_synthetic + assert deleverager_collateral_delta = tx.amount_collateral + assert deleveragable_collateral_delta = -tx.amount_collateral + end + + # Performing the transaction on both positions first to get the funded positions. + let (range_check_ptr, positions_dict : DictAccess*, local deleveragable_funded_position, + local deleveragable_updated_position, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=carried_state.positions_dict, + position_id=tx.deleveragable_position_id, + request_public_key=0, + collateral_delta=deleveragable_collateral_delta, + synthetic_asset_id=tx.synthetic_asset_id, + synthetic_delta=deleveragable_synthetic_delta, + global_funding_indices=global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=batch_config.general_config) + assert_success(return_code) + + let (range_check_ptr, local positions_dict, deleverager_funded_position : Position*, _, + return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=positions_dict, + position_id=tx.deleverager_position_id, + request_public_key=0, + collateral_delta=deleverager_collateral_delta, + synthetic_asset_id=tx.synthetic_asset_id, + synthetic_delta=deleverager_synthetic_delta, + global_funding_indices=global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=batch_config.general_config) + assert_success(return_code) + + # Validating that deleverager has enough synthetic (or minus synthetic) for the transaction. + let (range_check_ptr, deleverager_synthetic_balance) = position_get_asset_balance( + range_check_ptr=range_check_ptr, + position=deleverager_funded_position, + asset_id=tx.synthetic_asset_id) + + %{ error_code = ids.PerpetualErrorCode.ILLEGAL_POSITION_TRANSITION_ENLARGING_SYNTHETIC_HOLDINGS %} + if tx.deleverager_is_buying_synthetic != 0: + assert_nn_le{range_check_ptr=range_check_ptr}( + tx.amount_synthetic, -deleverager_synthetic_balance) + else: + assert_nn_le{range_check_ptr=range_check_ptr}( + tx.amount_synthetic, deleverager_synthetic_balance) + end + %{ del error_code %} + + # Check that deleveragable position is deleveragable. + let (range_check_ptr, local initial_tv, local initial_tr, return_code) = position_get_status( + range_check_ptr=range_check_ptr, + position=deleveragable_funded_position, + oracle_prices=carried_state.oracle_prices, + general_config=batch_config.general_config) + assert_success(return_code) + %{ error_code = ids.PerpetualErrorCode.UNDELEVERAGABLE_POSITION %} + assert_lt{range_check_ptr=range_check_ptr}(initial_tv, 0) + %{ del error_code %} + + # Validates that deleverage ratio for the deleverager is the maximal it can be while being valid + # for the deleveragable. 
In other words, validates that if we reduce the collateral the + # deleveragable gets from the transaction by 1, the transaction is invalid. + let (range_check_ptr, updated_tv, updated_tr, return_code) = position_get_status( + range_check_ptr=range_check_ptr, + position=deleveragable_updated_position, + oracle_prices=carried_state.oracle_prices, + general_config=batch_config.general_config) + assert_success(return_code) + # tv0 / tr0 > tv1 / tr1 <=> tv0 * tr1 > tv1 * tr0. + # tv is 96 bit. + # tr is 128 bit. + # tv*tr fits in 224 bits. + # Since tv can be negative, adding 2**224 to each side. + %{ error_code = ids.PerpetualErrorCode.UNFAIR_DELEVERAGE %} + assert_le_250_bit{range_check_ptr=range_check_ptr}( + %[2**224%] + (updated_tv - FXP_32_ONE) * initial_tr + 1, + %[2**224%] + initial_tv * updated_tr) + %{ del error_code %} + + let (carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=carried_state.orders_dict, + global_funding_indices=global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/deposit.cairo b/src/services/perpetual/cairo/transactions/deposit.cairo new file mode 100644 index 0000000..beb376c --- /dev/null +++ b/src/services/perpetual/cairo/transactions/deposit.cairo @@ -0,0 +1,71 @@ +from services.perpetual.cairo.definitions.constants import AMOUNT_UPPER_BOUND +from services.perpetual.cairo.definitions.perpetual_error_code import ( + PerpetualErrorCode, assert_success) +from services.perpetual.cairo.output.program_output import ( + Modification, PerpetualOutputs, perpetual_outputs_new) +from services.perpetual.cairo.position.position import Position +from services.perpetual.cairo.position.update_position import ( + NO_SYNTHETIC_DELTA_ASSET_ID, update_position_in_dict) +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict import dict_update +from starkware.cairo.common.math import assert_nn_le + +struct Deposit: + member public_key : felt + member position_id : felt + member amount : felt +end + +func execute_deposit( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : Deposit*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + # Check validity of deposit. + # public_key has no constraints. + # position_id is validated implicitly by update_position_in_dict(). 
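Referring back to the UNFAIR_DELEVERAGE check in the deleverage hunk above: the deleveragable position must receive no more value than necessary, so that with even one less fixed-point unit its TV/TR ratio would no longer improve. A worked Python sketch, under the assumption that FXP_32_ONE equals 2**32 (its actual value lives in definitions/constants.cairo, which is not shown here):

```python
# Sketch of the maximality condition, mirroring the assert_le_250_bit check:
#   (updated_tv - FXP_32_ONE) * initial_tr < initial_tv * updated_tr
# (assumption: FXP_32_ONE = 2**32; the 2**224 bias plays the same role as in
# the earlier ratio check).
FXP_32_ONE = 2**32
BIAS = 2**224

def deleverage_is_fair(initial_tv, initial_tr, updated_tv, updated_tr):
    return (BIAS + (updated_tv - FXP_32_ONE) * initial_tr + 1
            <= BIAS + initial_tv * updated_tr)

F = FXP_32_ONE
# Initial TV/TR = -5/10. An updated state of -2/5 improves the ratio and is
# maximal: with one less unit of value the ratio would no longer improve.
assert deleverage_is_fair(-5 * F, 10 * F, -2 * F, 5 * F)
# An updated state of -1/5 hands the deleveragable more value than necessary,
# so the check rejects it.
assert not deleverage_is_fair(-5 * F, 10 * F, -1 * F, 5 * F)
```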
+ %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_AMOUNT %} + assert_nn_le{range_check_ptr=range_check_ptr}(tx.amount, AMOUNT_UPPER_BOUND) + + %{ del error_code %} + let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=carried_state.positions_dict, + position_id=tx.position_id, + request_public_key=tx.public_key, + collateral_delta=tx.amount, + synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID, + synthetic_delta=0, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=batch_config.general_config) + assert_success(return_code) + + let (carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=carried_state.orders_dict, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + + # Write to output. + tempvar modification : Modification* = outputs.modifications_ptr + assert modification.public_key = tx.public_key + assert modification.position_id = tx.position_id + assert modification.biased_delta = tx.amount + AMOUNT_UPPER_BOUND + let (outputs : PerpetualOutputs*) = perpetual_outputs_new( + modifications_ptr=modification + Modification.SIZE, + forced_actions_ptr=outputs.forced_actions_ptr, + conditions_ptr=outputs.conditions_ptr, + funding_indices_table_ptr=outputs.funding_indices_table_ptr) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/execute_limit_order.cairo b/src/services/perpetual/cairo/transactions/execute_limit_order.cairo new file mode 100644 index 0000000..91b4fbf --- /dev/null +++ b/src/services/perpetual/cairo/transactions/execute_limit_order.cairo @@ -0,0 +1,137 @@ +from services.perpetual.cairo.definitions.constants import ( + AMOUNT_UPPER_BOUND, EXPIRATION_TIMESTAMP_UPPER_BOUND, NONCE_UPPER_BOUND, + POSITIVE_AMOUNT_LOWER_BOUND) +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.perpetual_error_code import ( + PerpetualErrorCode, assert_success) +from services.perpetual.cairo.order.limit_order import LimitOrder, limit_order_hash +from services.perpetual.cairo.order.order import validate_order_and_update_fulfillment +from services.perpetual.cairo.order.validate_limit_order import validate_limit_order_fairness +from services.perpetual.cairo.position.update_position import ( + NO_SYNTHETIC_DELTA_ASSET_ID, update_position_in_dict) +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict import dict_update +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math import assert_in_range, assert_le, assert_nn_le, assert_not_equal + +# Executes a limit order of one party. Each trade will invoke this function twice, once per +# each party. +# A limit order is of the form: +# "I want to buy/sell up to 'amount_synthetic' synthetic for 'amount_collateral' collateral in the +# ratio amount_synthetic/amount_collateral (or better) and pay at most 'amount_fee' in fees." 
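The deposit hunk above outputs modification.biased_delta = tx.amount + AMOUNT_UPPER_BOUND, and the forced-withdrawal code later in this diff uses AMOUNT_UPPER_BOUND - tx.amount: the signed collateral change is shifted so the on-chain output is never negative. A small sketch of this encoding, assuming AMOUNT_UPPER_BOUND = 2**64 (the real constant is defined in definitions/constants.cairo, not shown here):

```python
AMOUNT_UPPER_BOUND = 2**64  # Assumed value, for illustration only.

def encode_biased_delta(delta):
    assert -AMOUNT_UPPER_BOUND < delta < AMOUNT_UPPER_BOUND
    return delta + AMOUNT_UPPER_BOUND

def decode_biased_delta(biased):
    return biased - AMOUNT_UPPER_BOUND

assert decode_biased_delta(encode_biased_delta(150)) == 150   # A deposit.
assert decode_biased_delta(encode_biased_delta(-75)) == -75   # A withdrawal.
```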
+#
+# The actual amounts moved in this order are actual_collateral, actual_synthetic, actual_fee.
+# The function charges a fee and adds it to fee_position.
+#
+# Assumption (for validate_limit_order_fairness):
+# 0 <= actual_collateral < AMOUNT_UPPER_BOUND
+# 0 <= actual_fee < AMOUNT_UPPER_BOUND
+# AMOUNT_UPPER_BOUND**2 <= rc_bound.
+# The fee position doesn't have synthetic assets and cannot participate in an order.
+func execute_limit_order(
+        pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*,
+        carried_state : CarriedState*, batch_config : BatchConfig*, limit_order : LimitOrder*,
+        actual_collateral, actual_synthetic, actual_fee) -> (
+        pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*,
+        carried_state : CarriedState*):
+    alloc_locals
+    local general_config : GeneralConfig* = batch_config.general_config
+
+    assert_not_equal(limit_order.position_id, general_config.fee_position_info.position_id)
+
+    # Check that asset_id_collateral is collateral.
+    %{ error_code = ids.PerpetualErrorCode.INVALID_COLLATERAL_ASSET_ID %}
+    assert limit_order.asset_id_collateral = general_config.collateral_asset_info.asset_id
+    # No need to delete error code because it is changed in the next line.
+
+    # 0 < limit_order.amount_collateral < AMOUNT_UPPER_BOUND.
+    # 0 <= limit_order.amount_fee < AMOUNT_UPPER_BOUND.
+    # Note that limit_order.amount_synthetic is checked by validate_order_and_update_fulfillment.
+    %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_POSITIVE_AMOUNT %}
+    assert_in_range{range_check_ptr=range_check_ptr}(
+        limit_order.amount_collateral, POSITIVE_AMOUNT_LOWER_BOUND, AMOUNT_UPPER_BOUND)
+    %{ del error_code %}
+    assert_nn_le{range_check_ptr=range_check_ptr}(limit_order.amount_fee, AMOUNT_UPPER_BOUND - 1)
+
+    # actual_synthetic > 0. To prevent replay.
+    # Note that actual_synthetic <= AMOUNT_UPPER_BOUND is checked in
+    # validate_order_and_update_fulfillment.
+    %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_POSITIVE_AMOUNT %}
+    assert_le{range_check_ptr=range_check_ptr}(POSITIVE_AMOUNT_LOWER_BOUND, actual_synthetic)
+    %{ del error_code %}
+
+    let (range_check_ptr) = validate_limit_order_fairness(
+        range_check_ptr=range_check_ptr,
+        limit_order=limit_order,
+        actual_collateral=actual_collateral,
+        actual_synthetic=actual_synthetic,
+        actual_fee=actual_fee)
+
+    # Note that by using update_position_in_dict with limit_order.position_id we check that
+    # 0 <= limit_order.position_id < 2**POSITION_TREE_HEIGHT = POSITION_ID_UPPER_BOUND.
+    # The expiration_timestamp and nonce are validated in validate_order_and_update_fulfillment.
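validate_limit_order_fairness itself is defined in validate_limit_order.cairo and is not part of this hunk. The sketch below only illustrates what the "ratio ... (or better)" wording in the comment above means for a party buying synthetic, with the division replaced by cross-multiplication:

```python
# Illustration only: the collateral actually paid per synthetic unit must not
# exceed the quoted ratio.
#   actual_collateral / actual_synthetic <= amount_collateral / amount_synthetic
#   <=> actual_collateral * amount_synthetic <= amount_collateral * actual_synthetic
def buyer_price_is_fair(actual_collateral, actual_synthetic,
                        amount_collateral, amount_synthetic):
    return (actual_collateral * amount_synthetic
            <= amount_collateral * actual_synthetic)

# Quoted: up to 10 synthetic for 1000 collateral (100 per unit).
assert buyer_price_is_fair(300, 3, 1000, 10)       # Pays 100 per unit.
assert not buyer_price_is_fair(350, 3, 1000, 10)   # Pays about 117 per unit.
```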
+ let (message_hash) = limit_order_hash{pedersen_ptr=pedersen_ptr}(limit_order=limit_order) + local pedersen_ptr : HashBuiltin* = pedersen_ptr + + let (local range_check_ptr, local ecdsa_ptr, + local orders_dict) = validate_order_and_update_fulfillment( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + orders_dict=carried_state.orders_dict, + message_hash=message_hash, + order=limit_order.base, + min_expiration_timestamp=batch_config.min_expiration_timestamp, + update_amount=actual_synthetic, + full_amount=limit_order.amount_synthetic) + + local collateral_delta + local synthetic_delta + if limit_order.is_buying_synthetic != 0: + assert collateral_delta = (-actual_collateral) - actual_fee + assert synthetic_delta = actual_synthetic + else: + assert collateral_delta = actual_collateral - actual_fee + assert synthetic_delta = -actual_synthetic + end + + let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=carried_state.positions_dict, + position_id=general_config.fee_position_info.position_id, + request_public_key=general_config.fee_position_info.public_key, + collateral_delta=actual_fee, + synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID, + synthetic_delta=0, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=general_config) + assert_success(return_code) + + let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=positions_dict, + position_id=limit_order.position_id, + request_public_key=limit_order.base.public_key, + collateral_delta=collateral_delta, + synthetic_asset_id=limit_order.asset_id_synthetic, + synthetic_delta=synthetic_delta, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=general_config) + assert_success(return_code) + + let (carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=orders_dict, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state) +end diff --git a/src/services/perpetual/cairo/transactions/forced_trade.cairo b/src/services/perpetual/cairo/transactions/forced_trade.cairo new file mode 100644 index 0000000..f6e3446 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/forced_trade.cairo @@ -0,0 +1,203 @@ +from services.perpetual.cairo.definitions.constants import AMOUNT_UPPER_BOUND +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.perpetual_error_code import ( + PerpetualErrorCode, assert_success) +from services.perpetual.cairo.output.forced import ( + ForcedAction, ForcedActionType, ForcedTradeAction, forced_trade_action_new) +from services.perpetual.cairo.output.program_output import PerpetualOutputs, perpetual_outputs_new +from services.perpetual.cairo.position.position import Position +from services.perpetual.cairo.position.update_position import update_position +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from 
starkware.cairo.common.dict import dict_update
+from starkware.cairo.common.math import assert_nn_le, assert_not_equal
+
+struct ForcedTrade:
+    member public_key_a : felt
+    member public_key_b : felt
+    member position_id_a : felt
+    member position_id_b : felt
+    member synthetic_asset_id : felt
+    member amount_collateral : felt
+    member amount_synthetic : felt
+    member is_party_a_buying_synthetic : felt
+    member nonce : felt
+    member is_valid : felt
+end
+
+func try_to_trade(
+        range_check_ptr, carried_state : CarriedState*, position_buyer : Position*,
+        position_seller : Position*, public_key_buyer, public_key_seller, synthetic_asset_id,
+        amount_collateral, amount_synthetic, general_config : GeneralConfig*) -> (
+        range_check_ptr, position_buyer : Position*, position_seller : Position*, return_code):
+    alloc_locals
+    # update_position will return the funded position as the updated position if it failed.
+    let (range_check_ptr, local updated_position_buyer, local funded_position_buyer,
+        local return_code_a) = update_position(
+        range_check_ptr=range_check_ptr,
+        position=position_buyer,
+        request_public_key=public_key_buyer,
+        collateral_delta=-amount_collateral,
+        synthetic_asset_id=synthetic_asset_id,
+        synthetic_delta=amount_synthetic,
+        global_funding_indices=carried_state.global_funding_indices,
+        oracle_prices=carried_state.oracle_prices,
+        general_config=general_config)
+    let (range_check_ptr, local updated_position_seller, local funded_position_seller,
+        local return_code_b) = update_position(
+        range_check_ptr=range_check_ptr,
+        position=position_seller,
+        request_public_key=public_key_seller,
+        collateral_delta=amount_collateral,
+        synthetic_asset_id=synthetic_asset_id,
+        synthetic_delta=-amount_synthetic,
+        global_funding_indices=carried_state.global_funding_indices,
+        oracle_prices=carried_state.oracle_prices,
+        general_config=general_config)
+
+    # Assumes that the return code for success is zero and all other error codes are positive.
+    if return_code_a + return_code_b == 0:
+        return (
+            range_check_ptr=range_check_ptr,
+            position_buyer=updated_position_buyer,
+            position_seller=updated_position_seller,
+            return_code=PerpetualErrorCode.SUCCESS)
+    end
+    local return_code
+    if return_code_a == 0:
+        return_code = return_code_b
+    else:
+        return_code = return_code_a
+    end
+    return (
+        range_check_ptr=range_check_ptr,
+        position_buyer=funded_position_buyer,
+        position_seller=funded_position_seller,
+        return_code=return_code)
+end
+
+# Executes a forced trade between two parties, where both parties agree on the exact trade details.
+# The forced trade is requested by party_a onchain, and is signed by party_b (verified onchain).
+# The forced trade can be specified as a false forced trade with the is_valid member. The following
+# assumptions are made on the transaction; a false trade isn't guaranteed to be accepted if
+# they aren't met:
+# 1. Position id is in range.
+# 2. The trade is between two different positions.
+# 3. The collateral asset id is indeed collateral.
+# 4. The synthetic asset id is in the configuration.
+# 5. The amounts are in range.
+# 6. The nonce is in range.
+func execute_forced_trade(
+        pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*,
+        carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*,
+        tx : ForcedTrade*) -> (
+        pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*,
+        carried_state : CarriedState*, outputs : PerpetualOutputs*):
+    alloc_locals
+
+    # Check fields.
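try_to_trade above merges the two parties' return codes, relying, as its comment states, on SUCCESS being zero and every error code being positive. A direct Python restatement of that merge:

```python
SUCCESS = 0

def merge_return_codes(return_code_a, return_code_b):
    if return_code_a + return_code_b == 0:
        return SUCCESS
    # Report party A's error if it failed, otherwise party B's.
    return return_code_a if return_code_a != SUCCESS else return_code_b

assert merge_return_codes(0, 0) == SUCCESS
assert merge_return_codes(3, 0) == 3
assert merge_return_codes(0, 5) == 5
assert merge_return_codes(3, 5) == 3
```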
+ # public_keys are valid since they come from the previous position. + # position_ids are verified by being a dict key. + # synthetic_asset_id is verified in add_asset. + # Note: We don't mind failing here even when is_valid is 0, since the amount and the position_id + # should be verified on-chain when a user makes the forced action request. + assert_nn_le{range_check_ptr=range_check_ptr}(tx.amount_collateral, AMOUNT_UPPER_BOUND) + assert_nn_le{range_check_ptr=range_check_ptr}(tx.amount_synthetic, AMOUNT_UPPER_BOUND) + %{ error_code = ids.PerpetualErrorCode.SAME_POSITION_ID %} + assert_not_equal(tx.position_id_a, tx.position_id_b) + + # Read both positions. + local position_a : Position* + local position_b : Position* + let positions_dict = carried_state.positions_dict + %{ + del error_code + ids.position_a = __dict_manager.get_dict(ids.positions_dict)[ids.tx.position_id_a] + ids.position_b = __dict_manager.get_dict(ids.positions_dict)[ids.tx.position_id_b] + %} + + local new_position_a : Position* + local new_position_b : Position* + # Try to update the position. + if tx.is_party_a_buying_synthetic != 0: + let (range_check_ptr, new_position_buyer, new_position_seller, return_code) = try_to_trade( + range_check_ptr=range_check_ptr, + carried_state=carried_state, + position_buyer=position_a, + position_seller=position_b, + public_key_buyer=tx.public_key_a, + public_key_seller=tx.public_key_b, + synthetic_asset_id=tx.synthetic_asset_id, + amount_collateral=tx.amount_collateral, + amount_synthetic=tx.amount_synthetic, + general_config=batch_config.general_config) + new_position_a = new_position_buyer + new_position_b = new_position_seller + else: + let (range_check_ptr, new_position_buyer, new_position_seller, return_code) = try_to_trade( + range_check_ptr=range_check_ptr, + carried_state=carried_state, + position_buyer=position_b, + position_seller=position_a, + public_key_buyer=tx.public_key_b, + public_key_seller=tx.public_key_a, + synthetic_asset_id=tx.synthetic_asset_id, + amount_collateral=tx.amount_collateral, + amount_synthetic=tx.amount_synthetic, + general_config=batch_config.general_config) + new_position_a = new_position_seller + new_position_b = new_position_buyer + end + + local range_check_ptr = range_check_ptr + if tx.is_valid != 0: + assert_success(return_code) + else: + assert_not_equal(return_code, PerpetualErrorCode.SUCCESS) + end + + # Update positions. + dict_update{dict_ptr=positions_dict}( + key=tx.position_id_a, + prev_value=cast(position_a, felt), + new_value=cast(new_position_a, felt)) + dict_update{dict_ptr=positions_dict}( + key=tx.position_id_b, + prev_value=cast(position_b, felt), + new_value=cast(new_position_b, felt)) + + let (carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=carried_state.orders_dict, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + + # Write the forced action to output. 
+ let forced_action : ForcedAction* = outputs.forced_actions_ptr + assert forced_action.forced_type = ForcedActionType.FORCED_TRADE + let (forced_trade_action) = forced_trade_action_new( + public_key_a=tx.public_key_a, + public_key_b=tx.public_key_b, + position_id_a=tx.position_id_a, + position_id_b=tx.position_id_b, + synthetic_asset_id=tx.synthetic_asset_id, + amount_collateral=tx.amount_collateral, + amount_synthetic=tx.amount_synthetic, + is_party_a_buying_synthetic=tx.is_party_a_buying_synthetic, + nonce=tx.nonce) + assert forced_action.forced_action = cast(forced_trade_action, felt*) + + let (outputs) = perpetual_outputs_new( + modifications_ptr=outputs.modifications_ptr, + forced_actions_ptr=outputs.forced_actions_ptr + ForcedAction.SIZE, + conditions_ptr=outputs.conditions_ptr, + funding_indices_table_ptr=outputs.funding_indices_table_ptr) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/forced_withdrawal.cairo b/src/services/perpetual/cairo/transactions/forced_withdrawal.cairo new file mode 100644 index 0000000..1ae98bd --- /dev/null +++ b/src/services/perpetual/cairo/transactions/forced_withdrawal.cairo @@ -0,0 +1,96 @@ +from services.perpetual.cairo.definitions.constants import AMOUNT_UPPER_BOUND +from services.perpetual.cairo.definitions.perpetual_error_code import ( + PerpetualErrorCode, assert_success) +from services.perpetual.cairo.output.forced import ( + ForcedAction, ForcedActionType, ForcedWithdrawalAction, forced_withdrawal_action_new) +from services.perpetual.cairo.output.program_output import ( + Modification, PerpetualOutputs, perpetual_outputs_new) +from services.perpetual.cairo.position.update_position import ( + NO_SYNTHETIC_DELTA_ASSET_ID, update_position_in_dict) +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.math import assert_nn_le, assert_not_equal + +struct ForcedWithdrawal: + member public_key : felt + member position_id : felt + member amount : felt + member is_valid : felt +end + +func execute_forced_withdrawal( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : ForcedWithdrawal*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + + # Check fields. + # public_key is valid since it comes from the previous position. + # position_id is verified by being a dict key. + # Note: We don't mind failing here even when is_valid is 0, since the amount and the position_id + # should be verified on-chain when a user makes the forced action request. + %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_AMOUNT %} + assert_nn_le{range_check_ptr=range_check_ptr}(tx.amount, AMOUNT_UPPER_BOUND) + %{ del error_code %} + + # Try to update the position. update_position_in_dict will not update the position if it fails. 
+ let (local range_check_ptr, local positions_dict, _, _, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=carried_state.positions_dict, + position_id=tx.position_id, + request_public_key=tx.public_key, + collateral_delta=-tx.amount, + synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID, + synthetic_delta=0, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=batch_config.general_config) + if tx.is_valid != 0: + assert_success(return_code) + else: + # Validate that the transition could not be completed. The types of failures that + # update_position_in_dict fails on are the types of failures that an invalid forced action + # can fail on. + assert_not_equal(return_code, PerpetualErrorCode.SUCCESS) + end + + let (local carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=carried_state.orders_dict, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + + # Write the forced action to output. + let forced_action : ForcedAction* = outputs.forced_actions_ptr + assert forced_action.forced_type = ForcedActionType.FORCED_WITHDRAWAL + let (forced_withdrawal_action) = forced_withdrawal_action_new( + public_key=tx.public_key, position_id=tx.position_id, amount=tx.amount) + assert forced_action.forced_action = cast(forced_withdrawal_action, felt*) + + if tx.is_valid != 0: + # Also output a modification. + tempvar modification : Modification* = outputs.modifications_ptr + assert modification.public_key = tx.public_key + assert modification.position_id = tx.position_id + assert modification.biased_delta = AMOUNT_UPPER_BOUND - tx.amount + tempvar modifications_ptr = modification + Modification.SIZE + else: + tempvar modifications_ptr = outputs.modifications_ptr + end + + let (outputs) = perpetual_outputs_new( + modifications_ptr=modifications_ptr, + forced_actions_ptr=outputs.forced_actions_ptr + ForcedAction.SIZE, + conditions_ptr=outputs.conditions_ptr, + funding_indices_table_ptr=outputs.funding_indices_table_ptr) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/funding_tick.cairo b/src/services/perpetual/cairo/transactions/funding_tick.cairo new file mode 100644 index 0000000..2eab51c --- /dev/null +++ b/src/services/perpetual/cairo/transactions/funding_tick.cairo @@ -0,0 +1,245 @@ +from services.perpetual.cairo.definitions.constants import ( + ASSET_ID_UPPER_BOUND, FUNDING_INDEX_LOWER_BOUND, FUNDING_INDEX_UPPER_BOUND, FXP_32_ONE) +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.objects import ( + FundingIndicesInfo, OraclePrice, OraclePrices) +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.output.program_output import PerpetualOutputs, perpetual_outputs_new +from services.perpetual.cairo.position.funding import FundingIndex +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.find_element import find_element +from starkware.cairo.common.math import ( + 
abs_value, assert_in_range, assert_le, assert_le_250_bit, assert_not_equal) +from starkware.cairo.common.registers import get_fp_and_pc + +struct FundingTick: + member global_funding_indices : FundingIndicesInfo* +end + +# Validate that funding index diff isn't too large. In other words, that: +# abs(change_in_funding_index) <= max_funding_rate * price * change_in_timestamp. +func validate_funding_index_diff_in_range( + range_check_ptr, max_funding_rate, funding_index_diff, timestamp_diff, price) -> ( + range_check_ptr): + let (funding_index_diff) = abs_value{range_check_ptr=range_check_ptr}(funding_index_diff) + # Using 250 bit version here because the second argument can be up to 2**160. + assert_le_250_bit{range_check_ptr=range_check_ptr}( + funding_index_diff * FXP_32_ONE, max_funding_rate * price * timestamp_diff) + return (range_check_ptr) +end + +# Arguments to validate_funding_tick_inner that remain constant throughout the recursive call. +struct ValidateFundingTickInnerArgs: + member max_funding_rate : felt + member timestamp_diff : felt +end + +func validate_funding_tick_inner_args_new(max_funding_rate, timestamp_diff) -> ( + args : ValidateFundingTickInnerArgs*): + let (fp_val, pc_val) = get_fp_and_pc() + return ( + args=cast(fp_val - 2 - ValidateFundingTickInnerArgs.SIZE, ValidateFundingTickInnerArgs*)) +end + +# Validates the funding tick recursively. Refer to the documentation of `validate_funding_tick` for +# the conditions a valid funding tick needs to hold. +# Each recursive call will advance new_funding_index_ptr and oracle_price_ptr until they point to an +# object with an asset_id that matches prev_funding_index_ptr. If no corresponding asset_id is found +# the function will fail. Then it will validate the funding diff according to the price (see +# 'validate_funding_index_diff_in_range'). After that it will advance prev_funding_index_ptr and +# repeat. +func validate_funding_tick_inner( + range_check_ptr, prev_funding_index_ptr : FundingIndex*, + new_funding_index_ptr : FundingIndex*, oracle_price_ptr : OraclePrice*, + last_new_funding_asset_id, args : ValidateFundingTickInnerArgs*) -> ( + range_check_ptr, prev_funding_index_ptr : FundingIndex*, + new_funding_index_ptr : FundingIndex*, oracle_price_ptr : OraclePrice*): + alloc_locals + local should_continue + local should_advance_oracle_price + local should_advance_new_funding_index + %{ + # Decide non-deterministically whether to advance oracle_price_ptr, new_funding_index_ptr. + # Also decide if we checked all the assets. + # validate_funding_tick will ensure that the final pointers are equal to the end pointers. + # This is sound because prev_funding_index_ptr will not advance until the other 2 pointers + # will have its asset id. 
+ is_prev_funding_index_done = \ + ids.prev_funding_index_ptr.address_ == prev_funding_index_end.address_ + is_new_funding_index_done = \ + ids.new_funding_index_ptr.address_ == new_funding_index_end.address_ + is_oracle_price_done = ids.oracle_price_ptr.address_ == oracle_price_end.address_ + + prev_asset_id = \ + ids.prev_funding_index_ptr.asset_id if not is_prev_funding_index_done \ + else ids.ASSET_ID_UPPER_BOUND + new_asset_id = \ + ids.new_funding_index_ptr.asset_id if not is_new_funding_index_done \ + else ids.ASSET_ID_UPPER_BOUND + oracle_asset_id = \ + ids.oracle_price_ptr.asset_id if not is_oracle_price_done else ids.ASSET_ID_UPPER_BOUND + + ids.should_advance_new_funding_index = int(new_asset_id < prev_asset_id) + ids.should_advance_oracle_price = int(oracle_asset_id < prev_asset_id) + ids.should_continue = int( + not (is_prev_funding_index_done and is_new_funding_index_done and is_oracle_price_done)) + %} + if should_continue == 0: + assert_le{range_check_ptr=range_check_ptr}( + last_new_funding_asset_id + 1, ASSET_ID_UPPER_BOUND) + should_advance_oracle_price = should_advance_oracle_price + should_advance_new_funding_index = should_advance_new_funding_index + return ( + range_check_ptr=range_check_ptr, + prev_funding_index_ptr=prev_funding_index_ptr, + new_funding_index_ptr=new_funding_index_ptr, + oracle_price_ptr=oracle_price_ptr) + end + + # Since we need to validate that prev_funding_indices is contained in new_funding_indices and + # in oracle_prices, we will advance new_funding_index_ptr and oracle_price_ptr until its asset + # id is equal to current_asset_id. + if should_advance_oracle_price != 0: + should_advance_new_funding_index = should_advance_new_funding_index + return validate_funding_tick_inner( + range_check_ptr=range_check_ptr, + prev_funding_index_ptr=prev_funding_index_ptr, + new_funding_index_ptr=new_funding_index_ptr, + oracle_price_ptr=oracle_price_ptr + OraclePrice.SIZE, + last_new_funding_asset_id=last_new_funding_asset_id, + args=args) + end + + # We are always going to advance new_funding_index_ptr if we reached here. Therefore we will now + # check that its asset_id is larger than its previous asset_id and that the index is in range. + assert_le{range_check_ptr=range_check_ptr}( + last_new_funding_asset_id + 1, new_funding_index_ptr.asset_id) + %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_FUNDING_INDEX %} + assert_in_range{range_check_ptr=range_check_ptr}( + new_funding_index_ptr.funding_index, FUNDING_INDEX_LOWER_BOUND, FUNDING_INDEX_UPPER_BOUND) + %{ del error_code %} + + if should_advance_new_funding_index != 0: + return validate_funding_tick_inner( + range_check_ptr=range_check_ptr, + prev_funding_index_ptr=prev_funding_index_ptr, + new_funding_index_ptr=new_funding_index_ptr + FundingIndex.SIZE, + oracle_price_ptr=oracle_price_ptr, + last_new_funding_asset_id=new_funding_index_ptr.asset_id, + args=args) + end + + tempvar current_asset_id = prev_funding_index_ptr.asset_id + current_asset_id = new_funding_index_ptr.asset_id + current_asset_id = oracle_price_ptr.asset_id + + # Now all asset ids are equal. We need to check the rate of the funding change. + # If we are here, then prev_funding_index_ptr hasn't reached its end. 
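Once all three pointers agree on an asset id, the next hunk calls validate_funding_index_diff_in_range (defined near the top of this file) to bound how fast the funding index may move. A hedged Python restatement of that bound; FXP_32_ONE = 2**32 is an assumption here, and the exact fixed-point conventions for max_funding_rate and price come from the general config and constants modules:

```python
FXP_32_ONE = 2**32  # Assumed value, for illustration only.

def funding_index_diff_in_range(funding_index_diff, max_funding_rate, price,
                                timestamp_diff):
    # abs(change_in_funding_index) <= max_funding_rate * price * change_in_timestamp,
    # with the index diff scaled by FXP_32_ONE exactly as in the Cairo check.
    return abs(funding_index_diff) * FXP_32_ONE <= (
        max_funding_rate * price * timestamp_diff)

# A 50-unit index move over 60 seconds passes under these made-up inputs; a
# 500000-unit move does not.
assert funding_index_diff_in_range(50, 2, 100 * FXP_32_ONE, 60)
assert not funding_index_diff_in_range(500000, 2, 100 * FXP_32_ONE, 60)
```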
+ let (range_check_ptr) = validate_funding_index_diff_in_range( + range_check_ptr=range_check_ptr, + max_funding_rate=args.max_funding_rate, + funding_index_diff=new_funding_index_ptr.funding_index - prev_funding_index_ptr.funding_index, + timestamp_diff=args.timestamp_diff, + price=oracle_price_ptr.price) + + return validate_funding_tick_inner( + range_check_ptr=range_check_ptr, + prev_funding_index_ptr=prev_funding_index_ptr + FundingIndex.SIZE, + new_funding_index_ptr=new_funding_index_ptr + FundingIndex.SIZE, + oracle_price_ptr=oracle_price_ptr + OraclePrice.SIZE, + last_new_funding_asset_id=new_funding_index_ptr.asset_id, + args=args) +end + +# Validate that: +# 1. prev_funding_indices is contained in new funding indices. +# 2. prev_funding_indices is contained in oracle prices. +# 3. new funding indices are in range for indices that are in prev_funding_indices. +# 4. new_funding_indices is sorted and has no duplicates. +func validate_funding_tick( + range_check_ptr, carried_state : CarriedState*, general_config : GeneralConfig*, + new_funding_indices : FundingIndicesInfo*) -> (range_check_ptr): + alloc_locals + tempvar prev_funding_indices : FundingIndicesInfo* = carried_state.global_funding_indices + tempvar oracle_prices : OraclePrices* = carried_state.oracle_prices + + let (args) = validate_funding_tick_inner_args_new( + max_funding_rate=general_config.max_funding_rate, + timestamp_diff=( + new_funding_indices.funding_timestamp - prev_funding_indices.funding_timestamp)) + + local prev_funding_index_end : FundingIndex* = ( + prev_funding_indices.funding_indices + + FundingIndex.SIZE * prev_funding_indices.n_funding_indices) + local new_funding_index_end : FundingIndex* = ( + new_funding_indices.funding_indices + + FundingIndex.SIZE * new_funding_indices.n_funding_indices) + local oracle_price_end : OraclePrice* = ( + oracle_prices.data + OraclePrice.SIZE * oracle_prices.len) + %{ + prev_funding_index_end = ids.prev_funding_index_end + new_funding_index_end = ids.new_funding_index_end + oracle_price_end = ids.oracle_price_end + %} + + let (range_check_ptr, returned_prev_funding_index_ptr, returned_new_funding_index_ptr, + returned_oracle_price_ptr) = validate_funding_tick_inner( + range_check_ptr=range_check_ptr, + prev_funding_index_ptr=prev_funding_indices.funding_indices, + new_funding_index_ptr=new_funding_indices.funding_indices, + oracle_price_ptr=oracle_prices.data, + last_new_funding_asset_id=-1, + args=args) + + # Validate that all the asset_ids that validate_funding_tick_inner() went through are all the + # asset ids there are. + prev_funding_index_end = returned_prev_funding_index_ptr + new_funding_index_end = returned_new_funding_index_ptr + oracle_price_end = returned_oracle_price_ptr + + return (range_check_ptr) +end + +func execute_funding_tick( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : FundingTick*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + let new_funding_indices : FundingIndicesInfo* = tx.global_funding_indices + # Check that new timestamp is larger than previous system time. + # If signatures will be required to verify OraclePricesTick, then the timestamps for the + # oracle prices in the carried state will be verified here. 
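The doc comment of validate_funding_tick above lists four conditions. Conditions 1, 2 and 4 can be restated in plain Python as below; this is a model only, since the Cairo code checks them in a single non-deterministic merge pass and additionally range-checks the new indices (condition 3):

```python
def validate_funding_tick_model(prev_assets, new_assets, oracle_assets):
    # Condition 4: the new funding indices are sorted with no duplicates.
    assert all(a < b for a, b in zip(new_assets, new_assets[1:]))
    # Conditions 1 and 2: every previously known asset still has a new funding
    # index and an oracle price.
    assert set(prev_assets) <= set(new_assets)
    assert set(prev_assets) <= set(oracle_assets)

validate_funding_tick_model(
    prev_assets=[1, 4], new_assets=[1, 2, 4], oracle_assets=[1, 3, 4, 5])
```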
+ assert_le{range_check_ptr=range_check_ptr}( + carried_state.system_time, new_funding_indices.funding_timestamp) + + let (range_check_ptr) = validate_funding_tick( + range_check_ptr=range_check_ptr, + carried_state=carried_state, + general_config=batch_config.general_config, + new_funding_indices=new_funding_indices) + + let (new_carried_state) = carried_state_new( + positions_dict=carried_state.positions_dict, + orders_dict=carried_state.orders_dict, + global_funding_indices=new_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=new_funding_indices.funding_timestamp) + + assert [outputs.funding_indices_table_ptr] = new_funding_indices + + let (outputs : PerpetualOutputs*) = perpetual_outputs_new( + modifications_ptr=outputs.modifications_ptr, + forced_actions_ptr=outputs.forced_actions_ptr, + conditions_ptr=outputs.conditions_ptr, + funding_indices_table_ptr=outputs.funding_indices_table_ptr + 1) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=new_carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/liquidate.cairo b/src/services/perpetual/cairo/transactions/liquidate.cairo new file mode 100644 index 0000000..f8399e9 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/liquidate.cairo @@ -0,0 +1,147 @@ +from services.perpetual.cairo.definitions.constants import AMOUNT_UPPER_BOUND, FXP_32_ONE +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.objects import FundingIndicesInfo, OraclePrices +from services.perpetual.cairo.definitions.perpetual_error_code import ( + PerpetualErrorCode, assert_success) +from services.perpetual.cairo.order.limit_order import LimitOrder +from services.perpetual.cairo.output.program_output import PerpetualOutputs +from services.perpetual.cairo.position.funding import position_apply_funding +from services.perpetual.cairo.position.position import Position, position_get_asset_balance +from services.perpetual.cairo.position.status import position_get_status +from services.perpetual.cairo.position.update_position import update_position +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from services.perpetual.cairo.transactions.execute_limit_order import execute_limit_order +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict import dict_update +from starkware.cairo.common.math import ( + assert_in_range, assert_le_250_bit, assert_nn_le, assert_not_equal) + +struct Liquidate: + member liquidator_order : LimitOrder* + # liquidator_position_id = liquidator_order.position_id. 
+ member liquidated_position_id : felt + member actual_collateral : felt + member actual_synthetic : felt + member actual_liquidator_fee : felt +end + +func execute_liquidate( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : Liquidate*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + local limit_order : LimitOrder* = tx.liquidator_order + local general_config : GeneralConfig* = batch_config.general_config + local oracle_prices : OraclePrices* = carried_state.oracle_prices + local global_funding_indices : FundingIndicesInfo* = carried_state.global_funding_indices + local synthetic_asset_id = limit_order.asset_id_synthetic + + local liquidated_position : Position* + + local collateral_delta + local synthetic_delta + + # Note that tx.actual_synthetic is checked in execute_limit_order. + assert_nn_le{range_check_ptr=range_check_ptr}(tx.actual_collateral, AMOUNT_UPPER_BOUND) + assert_nn_le{range_check_ptr=range_check_ptr}(tx.actual_liquidator_fee, AMOUNT_UPPER_BOUND) + + # Assert that the liquidator position and the liquidated position are distinct. + assert_not_equal(tx.liquidator_order.position_id, tx.liquidated_position_id) + + if limit_order.is_buying_synthetic == 0: + assert collateral_delta = -tx.actual_collateral + assert synthetic_delta = tx.actual_synthetic + else: + assert collateral_delta = tx.actual_collateral + assert synthetic_delta = -tx.actual_synthetic + end + + %{ + positions_dict = __dict_manager.get_dict(ids.carried_state.positions_dict) + ids.liquidated_position = positions_dict[ids.tx.liquidated_position_id] + %} + + let (range_check_ptr, local liquidated_funded_position) = position_apply_funding( + range_check_ptr=range_check_ptr, + position=liquidated_position, + global_funding_indices=global_funding_indices) + + # Check that liquidated position is liquidatable. + let (range_check_ptr, local updated_tv, local updated_tr, return_code) = position_get_status( + range_check_ptr=range_check_ptr, + position=liquidated_funded_position, + oracle_prices=oracle_prices, + general_config=general_config) + assert_success(return_code) + + # TR can be up to 2**128 and TV can be down to -2**95, Therefore we can't use assert_le. + %{ error_code = ids.PerpetualErrorCode.UNLIQUIDATABLE_POSITION %} + assert_le_250_bit{range_check_ptr=range_check_ptr}(updated_tv * FXP_32_ONE + 1, updated_tr) + %{ del error_code %} + + # We need to check that the synthetic balance in the liquidated position won't grow and will + # keep the same sign. + let (range_check_ptr, local initial_liquidated_asset_balance) = position_get_asset_balance( + range_check_ptr=range_check_ptr, + position=liquidated_funded_position, + asset_id=synthetic_asset_id) + + %{ error_code = ids.PerpetualErrorCode.ILLEGAL_POSITION_TRANSITION_ENLARGING_SYNTHETIC_HOLDINGS %} + if limit_order.is_buying_synthetic == 0: + # Initial_liquidated_asset_balance <= -synthetic_delta <= 0. + assert_in_range{range_check_ptr=range_check_ptr}( + -synthetic_delta, initial_liquidated_asset_balance, 1) + else: + # 0 <= -synthetic_delta <= initial_liquidated_asset_balance. + assert_nn_le{range_check_ptr=range_check_ptr}( + -synthetic_delta, initial_liquidated_asset_balance) + end + %{ del error_code %} + + # Updating the liquidated position. 
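The sign and size checks above guarantee that a liquidation only shrinks the liquidated position's synthetic exposure toward zero, never grows it or flips its sign. A compact Python restatement (illustration only):

```python
def synthetic_change_is_reducing(initial_balance, delta):
    updated = initial_balance + delta
    return abs(updated) <= abs(initial_balance) and initial_balance * updated >= 0

assert synthetic_change_is_reducing(-10, 4)        # A short position shrinks.
assert not synthetic_change_is_reducing(-10, 12)   # Would flip short to long.
assert not synthetic_change_is_reducing(5, 3)      # A long position grows.
```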
+ let (range_check_ptr, liquidated_updated_position, _, return_code) = update_position( + range_check_ptr=range_check_ptr, + position=liquidated_funded_position, + request_public_key=liquidated_funded_position.public_key, + collateral_delta=collateral_delta, + synthetic_asset_id=synthetic_asset_id, + synthetic_delta=synthetic_delta, + global_funding_indices=global_funding_indices, + oracle_prices=oracle_prices, + general_config=general_config) + assert_success(return_code) + + let positions_dict = carried_state.positions_dict + dict_update{dict_ptr=positions_dict}( + key=tx.liquidated_position_id, + prev_value=cast(liquidated_position, felt), + new_value=cast(liquidated_updated_position, felt)) + + let (carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=carried_state.orders_dict, + global_funding_indices=global_funding_indices, + oracle_prices=oracle_prices, + system_time=carried_state.system_time) + + let (pedersen_ptr, range_check_ptr, ecdsa_ptr, carried_state) = execute_limit_order( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + limit_order=limit_order, + actual_collateral=tx.actual_collateral, + actual_synthetic=tx.actual_synthetic, + actual_fee=tx.actual_liquidator_fee) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/oracle_prices_tick.cairo b/src/services/perpetual/cairo/transactions/oracle_prices_tick.cairo new file mode 100644 index 0000000..b47fcd4 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/oracle_prices_tick.cairo @@ -0,0 +1,175 @@ +from services.perpetual.cairo.definitions.constants import ASSET_ID_UPPER_BOUND +from services.perpetual.cairo.definitions.objects import ( + OraclePrice, OraclePrices, oracle_prices_new) +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.output.program_output import PerpetualOutputs +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.find_element import find_element, search_sorted_lower +from starkware.cairo.common.math import assert_in_range, assert_le +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.registers import get_fp_and_pc + +# A tick containing oracle prices for assets. +# The tick does not contain signatures. Instead, at the start of each batch, signatures are verified +# for the minimal and maximal prices that appeared for each asset (which have the most potential to +# cause liquidations). Each OraclePricesTick is subsequently verified to be within this price range. +struct OraclePricesTick: + member oracle_prices : OraclePrices* + member timestamp : felt +end + +# Inserts into new_oracle_price_ptr all prices from given array with asset_id less than +# asset_id_bound. Returns the amount of prices inserted. 
+func insert_oracle_prices_until_asset_id( + range_check_ptr, oracle_price_ptr : OraclePrice*, n_oracle_prices, asset_id_bound, + new_oracle_price_ptr : OraclePrice*) -> (range_check_ptr, n_new_oracle_prices): + alloc_locals + static_assert OraclePrice.asset_id == 0 + let (oracle_price_lower_end : OraclePrice*) = search_sorted_lower{ + range_check_ptr=range_check_ptr}( + array_ptr=oracle_price_ptr, + elm_size=OraclePrice.SIZE, + n_elms=n_oracle_prices, + key=asset_id_bound) + local range_check_ptr = range_check_ptr + local oracle_prices_lower_size = (oracle_price_lower_end - oracle_price_ptr) + memcpy(dst=new_oracle_price_ptr, src=oracle_price_ptr, len=oracle_prices_lower_size) + return ( + range_check_ptr=range_check_ptr, + n_new_oracle_prices=oracle_prices_lower_size / OraclePrice.SIZE) +end + +# Creates a new oracle prices array that is sorted and contains all prices from +# prev_oracle_price_ptr and tick_price_ptr. If an asset has prices in both of them, the price from +# tick_price_ptr will be taken. Also validates that tick_price_ptr is sorted and its prices are in +# the range defined in batch_config. +# Returns the new range_check_ptr and the amount of oracle_prices inserted. +func create_new_oracle_prices_and_validate_tick( + range_check_ptr, prev_oracle_price_ptr : OraclePrice*, n_oracle_prices, + tick_price_ptr : OraclePrice*, n_tick_prices, last_tick_asset_id, + batch_config : BatchConfig*, new_oracle_price_ptr : OraclePrice*) -> ( + range_check_ptr, n_new_oracle_prices): + if n_tick_prices == 0: + assert_le{range_check_ptr=range_check_ptr}(last_tick_asset_id + 1, ASSET_ID_UPPER_BOUND) + return insert_oracle_prices_until_asset_id( + range_check_ptr=range_check_ptr, + oracle_price_ptr=prev_oracle_price_ptr, + n_oracle_prices=n_oracle_prices, + asset_id_bound=ASSET_ID_UPPER_BOUND, + new_oracle_price_ptr=new_oracle_price_ptr) + end + alloc_locals + + # Inserting into new_oracle_price_ptr all prices from prev_oracle_price_ptr with asset ids + # smaller than the asset id in tick_price_ptr. + let (range_check_ptr, local n_oracle_prices_inserted) = insert_oracle_prices_until_asset_id( + range_check_ptr=range_check_ptr, + oracle_price_ptr=prev_oracle_price_ptr, + n_oracle_prices=n_oracle_prices, + asset_id_bound=tick_price_ptr.asset_id, + new_oracle_price_ptr=new_oracle_price_ptr) + + %{ error_code = ids.PerpetualErrorCode.UNSORTED_ORACLE_PRICES %} + assert_le{range_check_ptr=range_check_ptr}(last_tick_asset_id + 1, tick_price_ptr.asset_id) + %{ del error_code %} + + # Asserting that price in tick is in the range defined in batch config. + let (min_oracle_price : OraclePrice*) = find_element{range_check_ptr=range_check_ptr}( + array_ptr=batch_config.signed_min_oracle_prices, + elm_size=OraclePrice.SIZE, + n_elms=batch_config.n_oracle_prices, + key=tick_price_ptr.asset_id) + + let (max_oracle_price : OraclePrice*) = find_element{range_check_ptr=range_check_ptr}( + array_ptr=batch_config.signed_max_oracle_prices, + elm_size=OraclePrice.SIZE, + n_elms=batch_config.n_oracle_prices, + key=tick_price_ptr.asset_id) + + assert_in_range{range_check_ptr=range_check_ptr}( + tick_price_ptr.price, min_oracle_price.price, max_oracle_price.price + 1) + local range_check_ptr = range_check_ptr + + # Advance prev_oracle_price_ptr by n_oracle_prices_inserted. + local prev_oracle_price_ptr : OraclePrice* = ( + prev_oracle_price_ptr + n_oracle_prices_inserted * OraclePrice.SIZE) + # If the asset id in tick_price_ptr exists in prev_oracle_price_ptr, advance + # prev_oracle_price_ptr by an extra 1. 
+    local oracle_price_ptr1 : OraclePrice*
+    local n_oracle_prices1
+    if n_oracle_prices != n_oracle_prices_inserted:
+        if prev_oracle_price_ptr.asset_id == tick_price_ptr.asset_id:
+            oracle_price_ptr1 = prev_oracle_price_ptr + OraclePrice.SIZE
+            assert n_oracle_prices1 = n_oracle_prices - n_oracle_prices_inserted - 1
+        else:
+            oracle_price_ptr1 = prev_oracle_price_ptr
+            n_oracle_prices1 = n_oracle_prices - n_oracle_prices_inserted
+        end
+    else:
+        oracle_price_ptr1 = prev_oracle_price_ptr
+        n_oracle_prices1 = n_oracle_prices - n_oracle_prices_inserted
+    end
+    let prev_oracle_price_ptr : OraclePrice* = oracle_price_ptr1
+    let n_oracle_prices = n_oracle_prices1
+
+    # Advance new_oracle_price_ptr by the number of elements we inserted into it from
+    # prev_oracle_price_ptr.
+    let new_oracle_price_ptr = new_oracle_price_ptr + n_oracle_prices_inserted * OraclePrice.SIZE
+
+    # Copy the current tick asset into new_oracle_price_ptr.
+    memcpy(dst=new_oracle_price_ptr, src=tick_price_ptr, len=OraclePrice.SIZE)
+    let new_oracle_price_ptr = new_oracle_price_ptr + OraclePrice.SIZE
+
+    let (range_check_ptr, n_new_oracle_prices) = create_new_oracle_prices_and_validate_tick(
+        range_check_ptr=range_check_ptr,
+        prev_oracle_price_ptr=prev_oracle_price_ptr,
+        n_oracle_prices=n_oracle_prices,
+        tick_price_ptr=tick_price_ptr + OraclePrice.SIZE,
+        n_tick_prices=n_tick_prices - 1,
+        last_tick_asset_id=tick_price_ptr.asset_id,
+        batch_config=batch_config,
+        new_oracle_price_ptr=new_oracle_price_ptr)
+    return (
+        range_check_ptr=range_check_ptr,
+        n_new_oracle_prices=n_new_oracle_prices + n_oracle_prices_inserted + 1)
+end
+
+func execute_oracle_prices_tick(
+        pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*,
+        carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*,
+        tx : OraclePricesTick*) -> (
+        pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*,
+        carried_state : CarriedState*, outputs : PerpetualOutputs*):
+    alloc_locals
+    # Check that the new system time is not smaller than the previous system time.
+ assert_le{range_check_ptr=range_check_ptr}(carried_state.system_time, tx.timestamp) + + let (local new_oracle_price_ptr : OraclePrice*) = alloc() + local tick_prices : OraclePrices* = tx.oracle_prices + let (range_check_ptr, n_new_oracle_prices) = create_new_oracle_prices_and_validate_tick( + range_check_ptr=range_check_ptr, + prev_oracle_price_ptr=carried_state.oracle_prices.data, + n_oracle_prices=carried_state.oracle_prices.len, + tick_price_ptr=tick_prices.data, + n_tick_prices=tick_prices.len, + last_tick_asset_id=-1, + batch_config=batch_config, + new_oracle_price_ptr=new_oracle_price_ptr) + let (oracle_prices) = oracle_prices_new(len=n_new_oracle_prices, data=new_oracle_price_ptr) + + let (carried_state) = carried_state_new( + positions_dict=carried_state.positions_dict, + orders_dict=carried_state.orders_dict, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=oracle_prices, + system_time=tx.timestamp) + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/trade.cairo b/src/services/perpetual/cairo/transactions/trade.cairo new file mode 100644 index 0000000..4a4a797 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/trade.cairo @@ -0,0 +1,88 @@ +from services.perpetual.cairo.definitions.constants import AMOUNT_UPPER_BOUND +from services.perpetual.cairo.definitions.general_config import FeePositionInfo, GeneralConfig +from services.perpetual.cairo.definitions.perpetual_error_code import PerpetualErrorCode +from services.perpetual.cairo.order.limit_order import LimitOrder +from services.perpetual.cairo.output.program_output import PerpetualOutputs +from services.perpetual.cairo.position.position import Position +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from services.perpetual.cairo.transactions.execute_limit_order import execute_limit_order +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict import dict_update +from starkware.cairo.common.math import assert_in_range, assert_nn_le, assert_not_equal + +struct Trade: + # Party A is the party that buys synthetic and party B is the party that sells synthetic. + member party_a_order : LimitOrder* + member party_b_order : LimitOrder* + member actual_collateral : felt + member actual_synthetic : felt + member actual_a_fee : felt + member actual_b_fee : felt +end + +# Executes a trade between two parties, where both parties agree to a limit order +# and those orders match. +func execute_trade( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : Trade*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + + let trade : Trade* = tx + + # 0 <= trade.actual_collateral, trade.actual_a_fee, trade.actual_b_fee < AMOUNT_UPPER_BOUND. + # Note that actual_synthetic is checked in execute_limit_order. 
+ assert_nn_le{range_check_ptr=range_check_ptr}(trade.actual_collateral, AMOUNT_UPPER_BOUND - 1) + assert_nn_le{range_check_ptr=range_check_ptr}(trade.actual_a_fee, AMOUNT_UPPER_BOUND - 1) + assert_nn_le{range_check_ptr=range_check_ptr}(trade.actual_b_fee, AMOUNT_UPPER_BOUND - 1) + + # Check that party A is buying synthetic and party B is selling synthetic. + local buy_order : LimitOrder* = trade.party_a_order + assert buy_order.is_buying_synthetic = 1 + + local sell_order : LimitOrder* = trade.party_b_order + assert sell_order.is_buying_synthetic = 0 + + # Execute_limit_order will verify that A and B are not the fee position. + let (pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*) = execute_limit_order( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + limit_order=buy_order, + actual_collateral=trade.actual_collateral, + actual_synthetic=trade.actual_synthetic, + actual_fee=trade.actual_a_fee) + + # Check that orders match in asset id. + assert buy_order.asset_id_synthetic = sell_order.asset_id_synthetic + + # Check that orders' positions are distinct. + %{ error_code = ids.PerpetualErrorCode.SAME_POSITION_ID %} + assert_not_equal(buy_order.position_id, sell_order.position_id) + %{ del error_code %} + + let (pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*) = execute_limit_order( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + batch_config=batch_config, + limit_order=sell_order, + actual_collateral=trade.actual_collateral, + actual_synthetic=trade.actual_synthetic, + actual_fee=trade.actual_b_fee) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/transaction.cairo b/src/services/perpetual/cairo/transactions/transaction.cairo new file mode 100644 index 0000000..148916f --- /dev/null +++ b/src/services/perpetual/cairo/transactions/transaction.cairo @@ -0,0 +1,23 @@ +namespace TransactionType: + const DEPOSIT = 0 + const FORCED_TRADE = 1 + const FORCED_WITHDRAWAL = 2 + const FUNDING_TICK = 3 + const ORACLE_PRICES_TICK = 4 + const TRADE = 5 + const TRANSFER = 6 + const LIQUIDATE = 7 + const WITHDRAWAL = 8 + const DELEVERAGE = 9 + const CONDITIONAL_TRANSFER = 10 +end + +struct Transaction: + member tx_type : felt + member tx : felt* +end + +struct Transactions: + member len : felt + member data : Transaction* +end diff --git a/src/services/perpetual/cairo/transactions/transfer.cairo b/src/services/perpetual/cairo/transactions/transfer.cairo new file mode 100644 index 0000000..41235d2 --- /dev/null +++ b/src/services/perpetual/cairo/transactions/transfer.cairo @@ -0,0 +1,142 @@ +from services.exchange.cairo.definitions.constants import VAULT_ID_UPPER_BOUND +from services.exchange.cairo.order import OrderBase +from services.exchange.cairo.signature_message_hashes import ExchangeTransfer +from services.exchange.cairo.signature_message_hashes import transfer_hash as exchange_transfer_hash +from services.perpetual.cairo.definitions.constants import ( + AMOUNT_UPPER_BOUND, POSITION_ID_UPPER_BOUND) +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.perpetual_error_code import ( + 
PerpetualErrorCode, assert_success) +from services.perpetual.cairo.order.order import validate_order_and_update_fulfillment +from services.perpetual.cairo.output.program_output import PerpetualOutputs +from services.perpetual.cairo.position.update_position import ( + NO_SYNTHETIC_DELTA_ASSET_ID, update_position_in_dict) +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.hash import hash2 +from starkware.cairo.common.math import assert_nn_le, assert_not_equal + +struct Transfer: + member base : OrderBase* + member nonce : felt + # sender_public_key is the base's public_key. + member sender_position_id : felt + member receiver_public_key : felt + member receiver_position_id : felt + member amount : felt + member asset_id : felt + member expiration_timestamp : felt +end + +# See the documentation of transfer_hash under exchange/signature_message_hashes.cairo. +# Since there are currently no fees in transfer, max_amount_fee and asset_id_fee are zero. +# +# Assumptions: +# 0 <= nonce < NONCE_UPPER_BOUND +# 0 <= sender_position_id, receiver_position_id, src_fee_position_id < POSITION_ID_UPPER_BOUND +# 0 <= amount, max_amount_fee < AMOUNT_UPPER_BOUND +# 0 <= expiration_timestamp < EXPIRATION_TIMESTAMP_UPPER_BOUND. +func transfer_hash(pedersen_ptr : HashBuiltin*, transfer : Transfer*, condition : felt) -> ( + pedersen_ptr : HashBuiltin*, message): + alloc_locals + static_assert POSITION_ID_UPPER_BOUND == VAULT_ID_UPPER_BOUND + + let (local exchange_transfer : ExchangeTransfer*) = alloc() + assert exchange_transfer.base = transfer.base + assert exchange_transfer.sender_vault_id = transfer.sender_position_id + assert exchange_transfer.receiver_public_key = transfer.receiver_public_key + assert exchange_transfer.receiver_vault_id = transfer.receiver_position_id + assert exchange_transfer.amount = transfer.amount + assert exchange_transfer.asset_id = transfer.asset_id + # The sender is the one that pays the fee. + assert exchange_transfer.src_fee_vault_id = transfer.sender_position_id + assert exchange_transfer.asset_id_fee = 0 + assert exchange_transfer.max_amount_fee = 0 + + let (transfer_hash) = exchange_transfer_hash{pedersen_ptr=pedersen_ptr}( + transfer=exchange_transfer, condition=condition) + return (pedersen_ptr=pedersen_ptr, message=transfer_hash) +end + +func execute_transfer( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : Transfer*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + %{ error_code = ids.PerpetualErrorCode.SAME_POSITION_ID %} + assert_not_equal(tx.sender_position_id, tx.receiver_position_id) + %{ error_code = ids.PerpetualErrorCode.OUT_OF_RANGE_AMOUNT %} + assert_nn_le{range_check_ptr=range_check_ptr}(tx.amount, AMOUNT_UPPER_BOUND - 1) + %{ del error_code %} + local range_check_ptr = range_check_ptr + # expiration_timestamp and nonce will be validated in validate_order_and_update_fulfillment. + # Asset id is in range because we check that it's equal to the collateral asset id. 
+ # Sender/Reciever's position id will be validated by update_position_in_dict. + + local general_config : GeneralConfig* = batch_config.general_config + # Validate that asset is collateral. + %{ error_code = ids.PerpetualErrorCode.INVALID_COLLATERAL_ASSET_ID %} + assert tx.asset_id = general_config.collateral_asset_info.asset_id + %{ del error_code %} + + let (local pedersen_ptr, message_hash) = transfer_hash( + pedersen_ptr=pedersen_ptr, transfer=tx, condition=0) + + let (range_check_ptr, local ecdsa_ptr, + local orders_dict) = validate_order_and_update_fulfillment( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + orders_dict=carried_state.orders_dict, + message_hash=message_hash, + order=tx.base, + min_expiration_timestamp=batch_config.min_expiration_timestamp, + update_amount=tx.amount, + full_amount=tx.amount) + + # Update the sender's position. + let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=carried_state.positions_dict, + position_id=tx.sender_position_id, + request_public_key=tx.base.public_key, + collateral_delta=-tx.amount, + synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID, + synthetic_delta=0, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=general_config) + assert_success(return_code) + + # Update the receiver's position. + let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=positions_dict, + position_id=tx.receiver_position_id, + request_public_key=tx.receiver_public_key, + collateral_delta=tx.amount, + synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID, + synthetic_delta=0, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=general_config) + assert_success(return_code) + + let (carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=orders_dict, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/cairo/transactions/withdrawal.cairo b/src/services/perpetual/cairo/transactions/withdrawal.cairo new file mode 100644 index 0000000..f9d8d6a --- /dev/null +++ b/src/services/perpetual/cairo/transactions/withdrawal.cairo @@ -0,0 +1,121 @@ +from services.exchange.cairo.order import OrderBase +from services.perpetual.cairo.definitions.constants import ( + AMOUNT_UPPER_BOUND, EXPIRATION_TIMESTAMP_UPPER_BOUND, NONCE_UPPER_BOUND, ORDER_ID_UPPER_BOUND, + POSITION_ID_UPPER_BOUND) +from services.perpetual.cairo.definitions.general_config import GeneralConfig +from services.perpetual.cairo.definitions.perpetual_error_code import assert_success +from services.perpetual.cairo.order.order import validate_order_and_update_fulfillment +from services.perpetual.cairo.output.program_output import ( + Modification, PerpetualOutputs, perpetual_outputs_new) +from services.perpetual.cairo.position.update_position import ( + NO_SYNTHETIC_DELTA_ASSET_ID, update_position_in_dict) +from services.perpetual.cairo.state.state import CarriedState, carried_state_new +from services.perpetual.cairo.transactions.batch_config import BatchConfig +from 
starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.hash import hash2 +from starkware.cairo.common.math import assert_nn_le + +struct Withdrawal: + member base : OrderBase* + member position_id : felt + member amount : felt +end + +# withdrawal_hash: +# Computes the hash of withdrawal request. +# +# The hash is defined as h(w1, w2) where h is the starkware pedersen function and w1, w2 are as +# follows: +# w1= asset_id_collateral +# w2= 0x6 (10 bit) || vault_from (64 bit) || nonce (64 bit) || expiration_timestamp (32 bit) +# || 0 (49 bit) +# +# Assumptions: +# 0 <= nonce < NONCE_UPPER_BOUND +# 0 <= position_id < POSITION_ID_UPPER_BOUND +# 0 <= expiration_timestamp < EXPIRATION_TIMESTAMP_UPPER_BOUND +# 0 <= amount < AMOUNT_UPPER_BOUND. +func withdrawal_hash( + pedersen_ptr : HashBuiltin*, withdrawal : Withdrawal*, asset_id_collateral) -> ( + pedersen_ptr : HashBuiltin*, message): + const WITHDRAWAL = 6 + let packed_message = WITHDRAWAL + let packed_message = packed_message * POSITION_ID_UPPER_BOUND + withdrawal.position_id + let packed_message = packed_message * NONCE_UPPER_BOUND + withdrawal.base.nonce + let packed_message = packed_message * AMOUNT_UPPER_BOUND + withdrawal.amount + let expiration_timestamp = withdrawal.base.expiration_timestamp + let packed_message = packed_message * EXPIRATION_TIMESTAMP_UPPER_BOUND + expiration_timestamp + let packed_message = packed_message * %[2**49%] # Padding. + + let (message) = hash2{hash_ptr=pedersen_ptr}(x=asset_id_collateral, y=packed_message) + return (pedersen_ptr=pedersen_ptr, message=message) +end + +func execute_withdrawal( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, batch_config : BatchConfig*, outputs : PerpetualOutputs*, + tx : Withdrawal*) -> ( + pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*, + carried_state : CarriedState*, outputs : PerpetualOutputs*): + alloc_locals + local general_config : GeneralConfig* = batch_config.general_config + + # The amount, nonce and expiration_timestamp are range checked in + # validate_order_and_update_fulfillment. + # By using update_position_in_dict with tx.position_id we check that + # 0 <= tx.position_id < 2**POSITION_TREE_HEIGHT = POSITION_ID_UPPER_BOUND. 
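As an illustrative aside (not part of the patch): a plain-Python sketch of the packing performed by withdrawal_hash above. The bound arguments stand in for the *_UPPER_BOUND constants imported at the top of the file (their concrete values are not restated here), and pack_withdrawal_message is a hypothetical helper name.

def pack_withdrawal_message(position_id, nonce, amount, expiration_timestamp,
                            position_id_bound, nonce_bound, amount_bound, timestamp_bound):
    packed = 6  # The WITHDRAWAL message type prefix.
    packed = packed * position_id_bound + position_id
    packed = packed * nonce_bound + nonce
    packed = packed * amount_bound + amount
    packed = packed * timestamp_bound + expiration_timestamp
    return packed * 2 ** 49  # 49-bit zero padding, as in the Cairo code.

# The signed message is then pedersen_hash(asset_id_collateral, packed_message).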
+ let (local pedersen_ptr, message_hash) = withdrawal_hash( + pedersen_ptr=pedersen_ptr, + withdrawal=tx, + asset_id_collateral=general_config.collateral_asset_info.asset_id) + + let (range_check_ptr, local ecdsa_ptr, + local orders_dict) = validate_order_and_update_fulfillment( + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + orders_dict=carried_state.orders_dict, + message_hash=message_hash, + order=tx.base, + min_expiration_timestamp=batch_config.min_expiration_timestamp, + update_amount=tx.amount, + full_amount=tx.amount) + + let (range_check_ptr, positions_dict, _, _, return_code) = update_position_in_dict( + range_check_ptr=range_check_ptr, + positions_dict=carried_state.positions_dict, + position_id=tx.position_id, + request_public_key=tx.base.public_key, + collateral_delta=-tx.amount, + synthetic_asset_id=NO_SYNTHETIC_DELTA_ASSET_ID, + synthetic_delta=0, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + general_config=general_config) + assert_success(return_code) + + let (carried_state) = carried_state_new( + positions_dict=positions_dict, + orders_dict=orders_dict, + global_funding_indices=carried_state.global_funding_indices, + oracle_prices=carried_state.oracle_prices, + system_time=carried_state.system_time) + + # Write to output. + tempvar modification : Modification* = outputs.modifications_ptr + assert modification.public_key = tx.base.public_key + assert modification.position_id = tx.position_id + assert modification.biased_delta = AMOUNT_UPPER_BOUND - tx.amount + let (outputs : PerpetualOutputs*) = perpetual_outputs_new( + modifications_ptr=modification + Modification.SIZE, + forced_actions_ptr=outputs.forced_actions_ptr, + conditions_ptr=outputs.conditions_ptr, + funding_indices_table_ptr=outputs.funding_indices_table_ptr) + + return ( + pedersen_ptr=pedersen_ptr, + range_check_ptr=range_check_ptr, + ecdsa_ptr=ecdsa_ptr, + carried_state=carried_state, + outputs=outputs) +end diff --git a/src/services/perpetual/public/CMakeLists.txt b/src/services/perpetual/public/CMakeLists.txt index 180671b..820b6ec 100644 --- a/src/services/perpetual/public/CMakeLists.txt +++ b/src/services/perpetual/public/CMakeLists.txt @@ -3,7 +3,7 @@ python_lib(perpetual_public_lib FILES perpetual_messages.py - public_generate_perpetual_config_hash.py + generate_perpetual_config_hash.py stark_cli.py LIBS diff --git a/src/services/perpetual/public/public_generate_perpetual_config_hash.py b/src/services/perpetual/public/generate_perpetual_config_hash.py similarity index 90% rename from src/services/perpetual/public/public_generate_perpetual_config_hash.py rename to src/services/perpetual/public/generate_perpetual_config_hash.py index eebe38d..ca99d94 100755 --- a/src/services/perpetual/public/public_generate_perpetual_config_hash.py +++ b/src/services/perpetual/public/generate_perpetual_config_hash.py @@ -28,13 +28,17 @@ import yaml +from services.perpetual.public.definitions.constants import ASSET_ID_UPPER_BOUND from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash_func CONFIG_FILE_NAME = 'production_general_config.yml' HASH_BYTES = 32 +ASSET_ID_BYTES = 15 +assert 2 ** (ASSET_ID_BYTES * 8) == ASSET_ID_UPPER_BOUND -def str2int(val): + +def str2int(val: str) -> int: """ Converts a decimal or hex string into an int. Also accepts an int and returns it unchanged. @@ -48,13 +52,23 @@ def str2int(val): return int(val, 10) -def bytes2str(val): +def bytes2str(val: bytes) -> str: """ Converts a bytes into a hex string. 
""" return f'0x{val.hex()}' +def pad_hex_string(val: str, bytes_len: int) -> str: + """ + Pads a hex string with leading zeros to match a length of bytes_len. + """ + assert val[:2] == '0x' + val_nibbles_len = (len(val) - 2) + assert val_nibbles_len <= 2 * bytes_len + return f'0x{"0" * (2 * bytes_len - val_nibbles_len)}{val[2:]}' + + def calculate_general_config_hash(config: dict) -> bytes: """ Calculates the hash of the general config without the synthetic assets info. @@ -145,7 +159,8 @@ def generate_config_hashes(config: dict) -> str: for asset_id in config['synthetic_assets_info'].keys(): config_hash_bytes = calculate_asset_hash(config=config, asset_id=asset_id) config_hash_hex = bytes2str(config_hash_bytes) - output += f'asset_id: {int(asset_id, 16)}, config_hash: {config_hash_hex}\n' + asset_id_padded = pad_hex_string(asset_id, ASSET_ID_BYTES) + output += f'asset_id: {asset_id_padded}, config_hash: {config_hash_hex}\n' output += '\n' return output diff --git a/src/starkware/CMakeLists.txt b/src/starkware/CMakeLists.txt index 88a6a9e..3bb04c3 100644 --- a/src/starkware/CMakeLists.txt +++ b/src/starkware/CMakeLists.txt @@ -1 +1,3 @@ add_subdirectory(crypto) +add_subdirectory(cairo) +add_subdirectory(python) diff --git a/src/starkware/cairo/CMakeLists.txt b/src/starkware/cairo/CMakeLists.txt new file mode 100644 index 0000000..705fd45 --- /dev/null +++ b/src/starkware/cairo/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(bootloader) +add_subdirectory(common) +add_subdirectory(lang) +add_subdirectory(sharp) diff --git a/src/starkware/cairo/bootloader/CMakeLists.txt b/src/starkware/cairo/bootloader/CMakeLists.txt new file mode 100644 index 0000000..b37e8a3 --- /dev/null +++ b/src/starkware/cairo/bootloader/CMakeLists.txt @@ -0,0 +1,61 @@ +python_lib(cairo_hash_program_lib + PREFIX starkware/cairo/bootloader + + FILES + hash_program.py + + LIBS + cairo_common_lib + cairo_compile_lib + cairo_version_lib + cairo_vm_crypto_lib +) + +python_venv(cairo_hash_program_venv + PYTHON python3.7 + LIBS + cairo_hash_program_lib +) + +python_exe(cairo_hash_program_exe + VENV cairo_hash_program_venv + MODULE starkware.cairo.bootloader.hash_program +) + +full_python_test(cairo_hash_program_test + PREFIX starkware/cairo/bootloader + PYTHON python3.7 + TESTED_MODULES starkware/cairo/bootloader + + FILES + hash_program_test.py + + LIBS + cairo_hash_program_lib + pip_pytest +) + +python_lib(cairo_bootloader_fact_topology_lib + PREFIX starkware/cairo/bootloader + FILES + fact_topology.py + + LIBS + pip_marshmallow + pip_marshmallow_dataclass +) + +python_lib(cairo_bootloader_generate_fact_lib + PREFIX starkware/cairo/bootloader + FILES + compute_fact.py + generate_fact.py + + LIBS + cairo_bootloader_fact_topology_lib + cairo_hash_program_lib + cairo_relocatable + cairo_vm_lib + pip_eth_hash + pip_pycryptodome +) diff --git a/src/starkware/cairo/bootloader/compute_fact.py b/src/starkware/cairo/bootloader/compute_fact.py new file mode 100644 index 0000000..a64a14e --- /dev/null +++ b/src/starkware/cairo/bootloader/compute_fact.py @@ -0,0 +1,81 @@ +import binascii +import dataclasses +from typing import List + +from eth_hash.auto import keccak + +from starkware.cairo.bootloader.fact_topology import FactTopology + + +def keccak_ints(values: List[int]) -> str: + """ + Computes the keccak of a list of ints. 
+ This function is compatible with + Web3.solidityKeccak(['uint256[]'], [values]).hex() + """ + return '0x' + binascii.hexlify( + keccak(b''.join(value.to_bytes(32, 'big') for value in values))).decode('ascii') + + +def generate_program_fact( + program_hash: int, program_output: List[int], fact_topology: FactTopology) -> str: + """ + Generates the program fact of the Cairo program with program_hash and program_output. + See GpsOutputParser.sol for more information on the way the fact is computed. + """ + return keccak_ints([ + program_hash, + generate_output_root(program_output=program_output, fact_topology=fact_topology).node_hash + ]) + + +@dataclasses.dataclass +class FactNode: + node_hash: int + end_offset: int + size: int + children: List['FactNode'] + + +def generate_output_root( + program_output: List[int], fact_topology: FactTopology) -> FactNode: + """ + Generates the root of the output Merkle tree for the program fact computation. + See GpsOutputParser.sol for more information on the way the fact is computed. + """ + # Create a copy of page_sizes. + page_sizes = list(fact_topology.page_sizes) + tree_structure = fact_topology.tree_structure + offset = 0 + node_stack: List[FactNode] = [] + for n_pages, n_nodes in zip(tree_structure[::2], tree_structure[1::2]): + # Push n_pages to the stack. + assert 0 <= n_pages <= len(page_sizes), 'Invalid tree structure: n_pages is out of range.' + for _ in range(n_pages): + page_size = page_sizes.pop(0) + page_hash = int(keccak_ints(program_output[offset:offset + page_size]), 16) + + offset += page_size + + node_stack.append(FactNode( + node_hash=page_hash, end_offset=offset, size=page_size, children=[])) + + assert 0 <= n_nodes <= len(node_stack), 'Invalid tree structure: n_nodes is out of range.' + if n_nodes > 0: + # Create a parent node to the last n_nodes in the head of the stack. + node_stack, child_nodes = node_stack[:-n_nodes], node_stack[-n_nodes:] + # Create an alternating list of hashes and end offsets. + node_data = [val for node in child_nodes for val in [node.node_hash, node.end_offset]] + node_stack.append(FactNode( + node_hash=1 + int(keccak_ints(node_data), 16), + end_offset=child_nodes[-1].end_offset, + size=sum(node.size for node in child_nodes), + children=child_nodes)) + + # Make sure there is one node in the stack (hash and end). + assert len(node_stack) == 1, 'Invalid tree structure: stack contains more than one node.' + # Make sure all pages were processed. + assert len(page_sizes) == 0, 'Invalid tree structure: not all pages were processed.' 
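As an illustrative aside (not part of the patch): for the default single-page topology [1, 0] (described in generate_fact.py below), the loop above produces a single leaf node covering the whole output. A minimal usage sketch, assuming the eth_hash dependency is installed; 0x123 is a made-up program hash, real callers obtain it from hash_program.py.

from starkware.cairo.bootloader.compute_fact import generate_output_root, generate_program_fact
from starkware.cairo.bootloader.fact_topology import FactTopology

program_output = [10, 20, 30]
topology = FactTopology(tree_structure=[1, 0], page_sizes=[len(program_output)])

# A single leaf node spanning the entire output.
root = generate_output_root(program_output=program_output, fact_topology=topology)
assert root.size == len(program_output) and root.children == []

fact = generate_program_fact(
    program_hash=0x123, program_output=program_output, fact_topology=topology)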
+ assert offset == node_stack[0].end_offset == len(program_output) + + return node_stack[0] diff --git a/src/starkware/cairo/bootloader/fact_topology.py b/src/starkware/cairo/bootloader/fact_topology.py new file mode 100644 index 0000000..1f1da92 --- /dev/null +++ b/src/starkware/cairo/bootloader/fact_topology.py @@ -0,0 +1,31 @@ +import dataclasses +import json +from typing import ClassVar, List, Type + +import marshmallow +import marshmallow_dataclass + +GPS_FACT_TOPOLOGY = 'gps_fact_topology' + + +@dataclasses.dataclass +class FactTopology: + tree_structure: List[int] + page_sizes: List[int] + + +@marshmallow_dataclass.dataclass +class FactTopologiesFile: + fact_topologies: List[FactTopology] + Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema + + +def load_fact_topologies(path) -> List[FactTopology]: + return FactTopologiesFile.Schema().load(json.load(open(path))).fact_topologies + + +@dataclasses.dataclass +class FactInfo: + program_output: List[int] + fact_topology: FactTopology + fact: str diff --git a/src/starkware/cairo/bootloader/generate_fact.py b/src/starkware/cairo/bootloader/generate_fact.py new file mode 100644 index 0000000..97d93f9 --- /dev/null +++ b/src/starkware/cairo/bootloader/generate_fact.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, List, Optional + +from starkware.cairo.bootloader.compute_fact import generate_program_fact +from starkware.cairo.bootloader.fact_topology import GPS_FACT_TOPOLOGY, FactInfo, FactTopology +from starkware.cairo.bootloader.hash_program import compute_program_hash_chain +from starkware.cairo.lang.vm.cairo_pie import CairoPie +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue + + +def get_program_output(cairo_pie: CairoPie) -> List[int]: + """ + Returns the program output. + """ + assert 'output' in cairo_pie.metadata.builtin_segments, 'The output builtin must be used.' + output = cairo_pie.metadata.builtin_segments['output'] + + def verify_int(x: MaybeRelocatable) -> int: + assert isinstance(x, int), \ + f'Expected program output to contain absolute values, found: {x}.' + return x + + return [ + verify_int(cairo_pie.memory[RelocatableValue(segment_index=output.index, offset=i)]) + for i in range(output.size)] + + +def get_cairo_pie_fact_info(cairo_pie: CairoPie, program_hash: Optional[int] = None) -> FactInfo: + """ + Generates the fact of the Cairo program of cairo_pie. Returns the cairo-pie fact info. + """ + program_output = get_program_output(cairo_pie=cairo_pie) + fact_topology = get_fact_topology_from_additional_data( + output_size=len(program_output), + output_builtin_additional_data=cairo_pie.additional_data['output_builtin']) + if program_hash is None: + program_hash = get_program_hash(cairo_pie) + fact = generate_program_fact(program_hash, program_output, fact_topology=fact_topology) + return FactInfo(program_output=program_output, fact_topology=fact_topology, fact=fact) + + +def get_program_hash(cairo_pie: CairoPie) -> int: + return compute_program_hash_chain(cairo_pie.metadata.program) + + +def get_page_sizes_from_page_dict(output_size: int, pages: dict) -> List[int]: + """ + Returns the sizes of the program output pages, given the pages dictionary that appears + in the additional attributes of the output builtin. + """ + # Make sure the pages are adjacent to each other. + + # The first page id is expected to be 1. + expected_page_id = 1 + # We don't expect anything on its start value. 
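As an illustrative aside (not part of the patch): a small worked instance of the page-size computation being defined here, using hypothetical page offsets. Pages 1 and 2 start at offsets 3 and 7; whatever precedes page 1 is the implicit page 0.

from starkware.cairo.bootloader.generate_fact import get_page_sizes_from_page_dict

# page_id -> (page_start, page_size), as stored in the output builtin's additional data.
pages = {'1': (3, 4), '2': (7, 2)}
assert get_page_sizes_from_page_dict(output_size=9, pages=pages) == [3, 4, 2]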
+ expected_page_start = None + # The size of page 0 is output_size if there are no other pages, or the start of page 1 + # otherwise. + page0_size = output_size + + for page_id_str, (page_start, page_size) in sorted(pages.items()): + page_id = int(page_id_str) + assert page_id == expected_page_id, f'Expected page id {expected_page_id}, found {page_id}.' + if page_id == 1: + assert isinstance(page_start, int) and 0 < page_start <= output_size, \ + f'Invalid page start {page_start}.' + page0_size = page_start + else: + assert page_start == expected_page_start, \ + f'Expected page start {expected_page_start}, found {page_start}.' + + assert isinstance(page_size, int) and 0 < page_size <= output_size, \ + f'Invalid page size {page_size}.' + + expected_page_start = page_start + page_size + expected_page_id += 1 + + if len(pages) > 0: + assert expected_page_start == output_size, 'Pages must cover the entire program output.' + + return [page0_size] + [page_size for _, (_, page_size) in sorted(pages.items())] + + +def get_fact_topology_from_additional_data( + output_size: int, output_builtin_additional_data: Dict[str, Any]) -> FactTopology: + """ + Returns the fact topology from the additional data of the output builtin. + """ + pages = output_builtin_additional_data['pages'] + attributes = output_builtin_additional_data['attributes'] + + # If the GPS_FACT_TOPOLOGY attribute is present, use it. Otherwise, the task is expected to + # use exactly one page (page 0). + if GPS_FACT_TOPOLOGY in attributes: + tree_structure = attributes[GPS_FACT_TOPOLOGY] + assert isinstance(tree_structure, list) and \ + len(tree_structure) % 2 == 0 and \ + 0 < len(tree_structure) <= 10 and \ + all(isinstance(x, int) and 0 <= x < 2**30 for x in tree_structure), \ + f"Invalid tree structure specified in the '{GPS_FACT_TOPOLOGY}' attribute." + else: + assert len(pages) == 0, \ + f"Additional pages cannot be used since the '{GPS_FACT_TOPOLOGY}' attribute is not " \ + 'specified.' + tree_structure = [1, 0] + + return FactTopology( + tree_structure=tree_structure, + page_sizes=get_page_sizes_from_page_dict(output_size, pages)) diff --git a/src/starkware/cairo/bootloader/hash_program.py b/src/starkware/cairo/bootloader/hash_program.py new file mode 100644 index 0000000..24e1bfc --- /dev/null +++ b/src/starkware/cairo/bootloader/hash_program.py @@ -0,0 +1,40 @@ +import argparse +import json + +from starkware.cairo.common.hash_chain import compute_hash_chain +from starkware.cairo.lang.compiler.program import Program, ProgramBase +from starkware.cairo.lang.version import __version__ +from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager + + +def compute_program_hash_chain(program: ProgramBase, bootloader_version=0): + """ + Computes a hash chain over a program, including the length of the data chain. + """ + builtin_list = [int.from_bytes(builtin.encode('ascii'), 'big') for builtin in program.builtins] + # The program header below is missing the data length, which is later added to the data_chain. 
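As an illustrative aside (not part of the patch): for a program with bootloader_version 0, main at offset 0, a single 'output' builtin and data [11, 22], the value computed below unrolls (per hash_chain.py) to h(6, h(0, h(0, h(1, h(enc('output'), h(11, 22)))))), where enc is the big-endian integer encoding of the builtin name and 6 is the data-chain length. A short sketch using the helpers from this patch:

from starkware.cairo.common.hash_chain import compute_hash_chain

builtin_list = [int.from_bytes(b'output', 'big')]
# [bootloader_version, main, n_builtins] + builtins + data.
data_chain = [0, 0, 1] + builtin_list + [11, 22]
program_hash = compute_hash_chain([len(data_chain)] + data_chain)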
+ program_header = [bootloader_version, program.main, len(program.builtins)] + builtin_list + data_chain = program_header + program.data + + return compute_hash_chain([len(data_chain)] + data_chain) + + +def main(): + parser = argparse.ArgumentParser( + description='A tool to compute the hash of a cairo program') + parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') + parser.add_argument( + '--program', type=argparse.FileType('r'), required=True, + help='The name of the program json file.') + parser.add_argument( + '--flavor', type=str, default='Release', choices=['Debug', 'Release', 'RelWithDebInfo'], + help='Build flavor') + args = parser.parse_args() + + with get_crypto_lib_context_manager(args.flavor): + program = Program.Schema().load(json.load(args.program)) + print(hex(compute_program_hash_chain(program))) + + +if __name__ == '__main__': + main() diff --git a/src/starkware/cairo/bootloader/hash_program_test.py b/src/starkware/cairo/bootloader/hash_program_test.py new file mode 100644 index 0000000..7201fd1 --- /dev/null +++ b/src/starkware/cairo/bootloader/hash_program_test.py @@ -0,0 +1,8 @@ +from starkware.cairo.common.hash_chain import compute_hash_chain +from starkware.cairo.lang.vm.crypto import pedersen_hash + + +def test_compute_hash_chain(): + data = [1, 2, 3] + res = compute_hash_chain(data) + assert res == pedersen_hash(1, pedersen_hash(2, 3)) diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt new file mode 100644 index 0000000..ad6877a --- /dev/null +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -0,0 +1,32 @@ +python_lib(cairo_common_lib + PREFIX starkware/cairo/common + FILES + alloc.cairo + cairo_builtins.cairo + dict.cairo + dict_access.cairo + dict.py + find_element.cairo + hash_chain.cairo + hash_chain.py + hash_state.cairo + hash.cairo + invoke.cairo + math_utils.py + math.cairo + memcpy.cairo + merkle_multi_update.cairo + merkle_update.cairo + registers.cairo + serialize.cairo + signature.cairo + small_merkle_tree.cairo + small_merkle_tree.py + squash_dict.cairo + ${CAIRO_COMMON_LIB_ADDITIONAL_FILES} + + LIBS + cairo_vm_crypto_lib + starkware_merkle_tree_lib + ${CAIRO_COMMON_LIB_ADDITIONAL_LIBS} +) diff --git a/src/starkware/cairo/common/alloc.cairo b/src/starkware/cairo/common/alloc.cairo new file mode 100644 index 0000000..320e3dd --- /dev/null +++ b/src/starkware/cairo/common/alloc.cairo @@ -0,0 +1,6 @@ +# Allocates a new memory segment. +func alloc() -> (ptr : felt*): + %{ memory[ap] = segments.add() %} + ap += 1 + return (ptr=cast([ap - 1], felt*)) +end diff --git a/src/starkware/cairo/common/cairo_builtins.cairo b/src/starkware/cairo/common/cairo_builtins.cairo new file mode 100644 index 0000000..fabceb4 --- /dev/null +++ b/src/starkware/cairo/common/cairo_builtins.cairo @@ -0,0 +1,19 @@ +# A representation of a HashBuiltin struct, specifying the hash builtin memory structure. +struct HashBuiltin: + member x : felt + member y : felt + member result : felt +end + +# A representation of a SignatureBuiltin struct, specifying the signature builtin memory structure. +struct SignatureBuiltin: + member pub_key : felt + member message : felt +end + +# A representation of a CheckpointsBuiltin struct, specifying the checkpoints builtin memory +# structure. 
+struct CheckpointsBuiltin: + member required_pc : felt + member required_fp : felt +end diff --git a/src/starkware/cairo/common/dict.cairo b/src/starkware/cairo/common/dict.cairo new file mode 100644 index 0000000..679e366 --- /dev/null +++ b/src/starkware/cairo/common/dict.cairo @@ -0,0 +1,108 @@ +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.squash_dict import squash_dict + +# Creates a new dict. +func dict_new() -> (res : DictAccess*): + %{ + if '__dict_manager' not in globals(): + from starkware.cairo.common.dict import DictManager + __dict_manager = DictManager() + + memory[ap] = __dict_manager.new_dict(segments, initial_dict) + del initial_dict + %} + ap += 1 + return (res=cast([ap - 1], DictAccess*)) +end + +# Reads a value from the dictionary and returns the result. +func dict_read{dict_ptr : DictAccess*}(key : felt) -> (value : felt): + alloc_locals + local value + %{ + dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) + dict_tracker.current_ptr += ids.DictAccess.SIZE + ids.value = dict_tracker.data[ids.key] + %} + assert dict_ptr.key = key + assert dict_ptr.prev_value = value + assert dict_ptr.new_value = value + let dict_ptr = dict_ptr + DictAccess.SIZE + return (value=value) +end + +# Writes a value to the dictionary, overriding the existing value. +func dict_write{dict_ptr : DictAccess*}(key : felt, new_value : felt): + %{ + dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) + dict_tracker.current_ptr += ids.DictAccess.SIZE + ids.dict_ptr.prev_value = dict_tracker.data[ids.key] + dict_tracker.data[ids.key] = ids.new_value + %} + assert dict_ptr.key = key + assert dict_ptr.new_value = new_value + let dict_ptr = dict_ptr + DictAccess.SIZE + return () +end + +# Updates a value in a dict. prev_value must be specified. A standalone read with no write should be +# performed by writing the same value. +# It is possible to get prev_value from __dict_manager using the hint: +# %{ ids.val = __dict_manager.get_dict(ids.dict_ptr)[ids.key] %} +func dict_update{dict_ptr : DictAccess*}(key : felt, prev_value : felt, new_value : felt): + %{ + # Verify dict pointer and prev value. + dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) + current_value = dict_tracker.data[ids.key] + assert current_value == ids.prev_value, \ + f'Wrong previous value in dict. Got {ids.prev_value}, expected {current_value}.' + + # Update value. + dict_tracker.data[ids.key] = ids.new_value + dict_tracker.current_ptr += ids.DictAccess.SIZE + %} + dict_ptr.key = key + dict_ptr.prev_value = prev_value + dict_ptr.new_value = new_value + let dict_ptr = dict_ptr + DictAccess.SIZE + return () +end + +# Returns a new dictionary with one DictAccess instance per key +# (value before and value after) which summarizes all the changes to that key. +# +# Example: +# Input: {(key1, 0, 2), (key1, 2, 7), (key2, 4, 1), (key1, 7, 5), (key2, 1, 2)} +# Output: {(key1, 0, 5), (key2, 4, 2)} +# +# This is a wrapper of squash_dict for dictionaries created by dict_new(). +func dict_squash{range_check_ptr}( + dict_accesses_start : DictAccess*, dict_accesses_end : DictAccess*) -> ( + squashed_dict_start : DictAccess*, squashed_dict_end : DictAccess*): + alloc_locals + + %{ + # Prepare arguments for dict_new. In particular, the same dictionary values should be copied + # to the new (squashed) dictionary. + vm_enter_scope({ + # Make __dict_manager accessible. + '__dict_manager': __dict_manager, + # Create a copy of the dict, in case it changes in the future. 
+ 'initial_dict': dict(__dict_manager.get_dict(ids.dict_accesses_end)), + }) + %} + let (local squashed_dict_start) = dict_new() + %{ vm_exit_scope() %} + + let (squashed_dict_end) = squash_dict( + dict_accesses=dict_accesses_start, + dict_accesses_end=dict_accesses_end, + squashed_dict=squashed_dict_start) + + %{ + # Update the DictTracker's current_ptr to point to the end of the squashed dict. + __dict_manager.get_tracker(ids.squashed_dict_start).current_ptr = \ + ids.squashed_dict_end.address_ + %} + return (squashed_dict_start=squashed_dict_start, squashed_dict_end=squashed_dict_end) +end diff --git a/src/starkware/cairo/common/dict.py b/src/starkware/cairo/common/dict.py new file mode 100644 index 0000000..fcd07b8 --- /dev/null +++ b/src/starkware/cairo/common/dict.py @@ -0,0 +1,60 @@ +import dataclasses +from typing import Dict + +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.vm_consts import VmConstsReference + + +@dataclasses.dataclass +class DictTracker: + """ + Tracks the python dict associated with a Cairo dict. + """ + # Python dict. + data: dict + # Pointer to the first unused position in the dict segment. + current_ptr: RelocatableValue + + +class DictManager: + """ + Manages dictionaries in a Cairo program. + Uses the segment index to associate the corresponding python dict with the Cairo dict. + """ + + def __init__(self): + # Mapping from segment index to the corresponding DictTracker of the Cairo dict. + self.trackers: Dict[int, DictTracker] = {} + + def new_dict(self, segments, initial_dict): + """ + Creates a new Cairo dictionary. The values of initial_dict can be integers, tuples or + lists. See MemorySegments.gen_arg(). + """ + base = segments.add() + assert base.segment_index not in self.trackers + self.trackers[base.segment_index] = DictTracker( + data={ + key: segments.gen_arg(value) for key, value in initial_dict.items()}, + current_ptr=base, + ) + return base + + def get_tracker(self, dict_ptr): + """ + Gets a dict tracker given the dict_ptr. + """ + if isinstance(dict_ptr, VmConstsReference): + dict_ptr = dict_ptr.address_ + dict_tracker = self.trackers.get(dict_ptr.segment_index) + if dict_tracker is None: + raise ValueError(f'Dictionary pointer {dict_ptr} was not created using dict_new().') + assert dict_tracker.current_ptr == dict_ptr, 'Wrong dict pointer supplied. ' \ + f'Got {dict_ptr}, expected {dict_tracker.current_ptr}.' + return dict_tracker + + def get_dict(self, dict_ptr) -> dict: + """ + Gets the python dict that corresponds to dict_ptr. + """ + return self.get_tracker(dict_ptr).data diff --git a/src/starkware/cairo/common/dict_access.cairo b/src/starkware/cairo/common/dict_access.cairo new file mode 100644 index 0000000..92d4178 --- /dev/null +++ b/src/starkware/cairo/common/dict_access.cairo @@ -0,0 +1,5 @@ +struct DictAccess: + member key : felt + member prev_value : felt + member new_value : felt +end diff --git a/src/starkware/cairo/common/find_element.cairo b/src/starkware/cairo/common/find_element.cairo new file mode 100644 index 0000000..fe2571a --- /dev/null +++ b/src/starkware/cairo/common/find_element.cairo @@ -0,0 +1,100 @@ +from starkware.cairo.common.math import assert_le, assert_nn_le + +# Finds an element in the array whose first field is key and returns a pointer +# to this element. +# Since cairo is non-deterministic this is an O(1) operation. +# Note however that if the array has multiple elements with said key the function may return any +# of those elements. 
+# +# Arguments: +# array_ptr - pointer to an array. +# elm_size - size of an element in the array. +# n_elms - number of element in the array. +# key - key to look for. +# +# Implicit arguments: +# range_check_ptr - range check builtin pointer. +# +# Returns: +# elm_ptr - pointer to an element in the array satisfying [ptr] = key. +# +# Optional hint variables: +# __find_element_index - the index that should be returned. If not specified, the function will +# search for it. +func find_element{range_check_ptr}(array_ptr : felt*, elm_size, n_elms, key) -> (elm_ptr : felt*): + alloc_locals + local index + %{ + if '__find_element_index' in globals(): + ids.index = __find_element_index + found_key = memory[ids.array_ptr + ids.elm_size * __find_element_index] + assert found_key == ids.key, \ + f'Invalid index found in __find_element_index. index: {__find_element_index}, ' \ + f'expected key {ids.key}, found key: {found_key}.' + # Delete __find_element_index to make sure it's not used for the next calls. + del __find_element_index + else: + for i in range(ids.n_elms): + if memory[ids.array_ptr + ids.elm_size * i] == ids.key: + ids.index = i + break + else: + raise ValueError(f'Key {ids.key} not found.') + %} + + assert_nn_le(a=index, b=n_elms - 1) + tempvar elm_ptr = array_ptr + elm_size * index + assert [elm_ptr] = key + return (elm_ptr=elm_ptr) +end + +# Given an array sorted by its first field, returns the pointer to the first element in the array +# whose first field is at least key. If no such item exists, returns a pointer to the end of the +# array. +# Prover assumption: all the keys (the first field in each item) are in [0, RANGE_CHECK_BOUND). +func search_sorted_lower{range_check_ptr}(array_ptr : felt*, elm_size, n_elms, key) -> ( + elm_ptr : felt*): + alloc_locals + local index + %{ + for i in range(ids.n_elms): + if memory[ids.array_ptr + ids.elm_size * i] >= ids.key: + ids.index = i + break + else: + ids.index = ids.n_elms + %} + + assert_nn_le(a=index, b=n_elms) + local elm_ptr : felt* = array_ptr + elm_size * index + + if index != n_elms: + assert_le(a=key, b=[elm_ptr]) + else: + tempvar range_check_ptr = range_check_ptr + end + + if index != 0: + assert_le(a=[elm_ptr - elm_size] + 1, b=key) + end + + return (elm_ptr=elm_ptr) +end + +# Given an array sorted by its first field, returns the pointer to the first element in the array +# whose first field is exactly key. If no such item exists, returns an undefined pointer, +# and success=0. +# Prover assumption: all the keys (the first field in each item) are in [0, RANGE_CHECK_BOUND). +func search_sorted{range_check_ptr}(array_ptr : felt*, elm_size, n_elms, key) -> ( + elm_ptr : felt*, success): + let (elm_ptr) = search_sorted_lower( + array_ptr=array_ptr, elm_size=elm_size, n_elms=n_elms, key=key) + tempvar array_end = array_ptr + elm_size * n_elms + if elm_ptr == array_end: + return (elm_ptr=array_ptr, success=0) + end + if [elm_ptr] != key: + return (elm_ptr=array_ptr, success=0) + end + return (elm_ptr=elm_ptr, success=1) +end diff --git a/src/starkware/cairo/common/hash.cairo b/src/starkware/cairo/common/hash.cairo new file mode 100644 index 0000000..abdd970 --- /dev/null +++ b/src/starkware/cairo/common/hash.cairo @@ -0,0 +1,18 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin + +# Computes the hash of two given field elements. +# The hash function is defined by the hash_ptr used. +# +# Arguments: +# hash_ptr - the hash builtin pointer. +# x, y - the two field elements to be hashed, in this order. 
+# +# Returns: +# result - the field element result of the hash. +func hash2{hash_ptr : HashBuiltin*}(x, y) -> (result): + hash_ptr.x = x + hash_ptr.y = y + let result = hash_ptr.result + let hash_ptr = hash_ptr + HashBuiltin.SIZE + return (result=result) +end diff --git a/src/starkware/cairo/common/hash_chain.cairo b/src/starkware/cairo/common/hash_chain.cairo new file mode 100644 index 0000000..9078666 --- /dev/null +++ b/src/starkware/cairo/common/hash_chain.cairo @@ -0,0 +1,50 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin + +# Computes a hash chain of a sequence whose length is given at [data_ptr] and the data starts at +# data_ptr + 1. The hash is calculated backwards (from the highest memory address to the lowest). +# For example, for the 3-element sequence [x, y, z] the hash is: +# h(3, h(x, h(y, z))) +# If data_length = 0, the function does not return (takes more than field prime steps). +func hash_chain{hash_ptr : HashBuiltin*}(data_ptr : felt*) -> (hash : felt): + struct LoopLocals: + member data_ptr : felt* + member hash_ptr : HashBuiltin* + member cur_hash : felt + end + + let data_length = ap + [data_length] = [data_ptr]; ap++ + let loop_frame = cast(ap, LoopLocals*) + + # Prepare the loop_frame for the first iteration of the hash_loop. + loop_frame.data_ptr = data_ptr + [data_length]; ap++ + loop_frame.hash_ptr = hash_ptr; ap++ + loop_frame.cur_hash = [loop_frame.data_ptr]; ap++ + + hash_loop: + let curr_frame = cast(ap - LoopLocals.SIZE, LoopLocals*) + let current_hash : HashBuiltin* = curr_frame.hash_ptr + + let new_data_ptr = curr_frame.data_ptr - 1 + let new_data = ap + [new_data] = [new_data_ptr]; ap++ + + let n_elements_to_hash = ap + # Assign current_hash inputs and allocate space for n_elements_to_hash. + [new_data] = current_hash.x; ap++ + curr_frame.cur_hash = current_hash.y + + # Set the frame for the next loop iteration (going backwards). + let next_frame = cast(ap, LoopLocals*) + next_frame.data_ptr = new_data_ptr; ap++ + next_frame.hash_ptr = curr_frame.hash_ptr + HashBuiltin.SIZE; ap++ + next_frame.cur_hash = current_hash.result; ap++ + + # Update n_elements_to_hash and loop accordingly. Note that the hash is calculated backwards. + [n_elements_to_hash] = next_frame.data_ptr - data_ptr + jmp hash_loop if [n_elements_to_hash] != 0 + + # Set the hash_ptr implicit argument and return the result. + let hash_ptr = next_frame.hash_ptr + return (hash=next_frame.cur_hash) +end diff --git a/src/starkware/cairo/common/hash_chain.py b/src/starkware/cairo/common/hash_chain.py new file mode 100644 index 0000000..33a1a13 --- /dev/null +++ b/src/starkware/cairo/common/hash_chain.py @@ -0,0 +1,12 @@ +from functools import reduce + +from starkware.cairo.lang.vm.crypto import pedersen_hash + + +def compute_hash_chain(data, hash_func=pedersen_hash): + """ + Computes a hash chain over the data, in the following order: + h(data[0], h(data[1], h(..., h(data[n-2], data[n-1])))). + """ + + return reduce(lambda x, y: hash_func(y, x), data[::-1]) diff --git a/src/starkware/cairo/common/hash_state.cairo b/src/starkware/cairo/common/hash_state.cairo new file mode 100644 index 0000000..3500fde --- /dev/null +++ b/src/starkware/cairo/common/hash_state.cairo @@ -0,0 +1,101 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash import hash2 +from starkware.cairo.common.registers import get_fp_and_pc + +# Stores the hash of a sequence of items. New items can be added to the hash state using hash_update +# and hash_update_single. 
The final hash of the entire sequence, including the sequence length, can +# be extracted using hash_finalize. +# For example, the hash of the sequence (x, y, z) is h(h(h(h(0, x), y), z), 3). +# In particular, the hash of zero items is h(0, 0). +struct HashState: + member current_hash : felt + member n_words : felt +end + +# Initializes a new HashState with no items. +func hash_init() -> (hash_state_ptr : HashState*): + alloc_locals + let (__fp__, _) = get_fp_and_pc() + local hash_state : HashState + hash_state.current_hash = 0 + hash_state.n_words = 0 + return (hash_state_ptr=&hash_state) +end + +# A helper function for 'hash_update', see its documentation. +# Computes the hash of an array of items, not including its length. +func hash_update_inner{hash_ptr : HashBuiltin*}( + data_ptr : felt*, data_length : felt, hash : felt) -> (hash : felt): + if data_length == 0: + return (hash=hash) + end + + alloc_locals + local data_last_ptr : felt* = data_ptr + data_length - 1 + struct LoopLocals: + member data_ptr : felt* + member hash_ptr : HashBuiltin* + member cur_hash : felt + end + + # Set up first iteration locals. + let first_locals : LoopLocals* = cast(ap, LoopLocals*) + first_locals.data_ptr = data_ptr; ap++ + first_locals.hash_ptr = hash_ptr; ap++ + first_locals.cur_hash = hash; ap++ + + # Do{. + hash_loop: + let prev_locals : LoopLocals* = cast(ap - LoopLocals.SIZE, LoopLocals*) + tempvar n_remaining_elements = data_last_ptr - prev_locals.data_ptr + + # Compute hash(cur_hash, [data_ptr]). + prev_locals.hash_ptr.x = prev_locals.cur_hash + assert prev_locals.hash_ptr.y = [prev_locals.data_ptr] # Allocates one memory cell. + + # Set up next iteration locals. + let next_locals : LoopLocals* = cast(ap, LoopLocals*) + next_locals.data_ptr = prev_locals.data_ptr + 1; ap++ + next_locals.hash_ptr = prev_locals.hash_ptr + HashBuiltin.SIZE; ap++ + next_locals.cur_hash = prev_locals.hash_ptr.result; ap++ + + # } while(n_remaining_elements != 0). + jmp hash_loop if n_remaining_elements != 0 + + # Return values from final iteration. + let final_locals : LoopLocals* = cast(ap - LoopLocals.SIZE, LoopLocals*) + let hash_ptr = final_locals.hash_ptr + return (hash=final_locals.cur_hash) +end + +# Adds each item in an array of items to the HashState. +# The array is represented by a pointer and a length. +func hash_update{hash_ptr : HashBuiltin*}( + hash_state_ptr : HashState*, data_ptr : felt*, data_length) -> ( + new_hash_state_ptr : HashState*): + alloc_locals + let (hash) = hash_update_inner( + data_ptr=data_ptr, data_length=data_length, hash=hash_state_ptr.current_hash) + let (__fp__, _) = get_fp_and_pc() + local new_hash_state : HashState + new_hash_state.current_hash = hash + assert new_hash_state.n_words = hash_state_ptr.n_words + data_length + return (new_hash_state_ptr=&new_hash_state) +end + +# Adds a single item to the HashState. +func hash_update_single{hash_ptr : HashBuiltin*}(hash_state_ptr : HashState*, item) -> ( + new_hash_state_ptr : HashState*): + alloc_locals + let (hash) = hash2(x=hash_state_ptr.current_hash, y=item) + let (__fp__, _) = get_fp_and_pc() + local new_hash_state : HashState + new_hash_state.current_hash = hash + assert new_hash_state.n_words = hash_state_ptr.n_words + 1 + return (new_hash_state_ptr=&new_hash_state) +end + +# Returns the hash result of the HashState. 
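+# For example, a minimal sketch (assuming a Pedersen hash_ptr implicit argument is in scope and
+# that x and y are felts) which hashes the two-item sequence (x, y):
+#   let (hash_state_ptr) = hash_init()
+#   let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=x)
+#   let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=y)
+#   let (hash) = hash_finalize(hash_state_ptr=hash_state_ptr)
+# The resulting hash should be h(h(h(0, x), y), 2).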
+func hash_finalize{hash_ptr : HashBuiltin*}(hash_state_ptr : HashState*) -> (hash): + return hash2(x=hash_state_ptr.current_hash, y=hash_state_ptr.n_words) +end diff --git a/src/starkware/cairo/common/invoke.cairo b/src/starkware/cairo/common/invoke.cairo new file mode 100644 index 0000000..e2b6331 --- /dev/null +++ b/src/starkware/cairo/common/invoke.cairo @@ -0,0 +1,19 @@ +# Calls func_ptr(args[0], args[1], ..., args[n_args - 1]) and forwards its return value. +func invoke(func_ptr, n_args : felt, args : felt*): + invoke_prepare_args(args_end=args + n_args, n_args=n_args) + call abs func_ptr + ret +end + +# Helper function for invoke(). +# Copies the memory range [args_end - n_args, args_end) to the memory range +# [final_ap - n_args, final_ap) where final_ap is the value of ap when the function returns. +func invoke_prepare_args(args_end : felt*, n_args : felt): + if n_args == 0: + return () + end + + invoke_prepare_args(args_end=args_end - 1, n_args=n_args - 1) + [ap] = [args_end - 1]; ap++ + return () +end diff --git a/src/starkware/cairo/common/math.cairo b/src/starkware/cairo/common/math.cairo new file mode 100644 index 0000000..9b112f5 --- /dev/null +++ b/src/starkware/cairo/common/math.cairo @@ -0,0 +1,272 @@ +# Inline functions with no locals. + +# Verifies that value != 0. The proof will fail otherwise. +func assert_not_zero(value): + %{ assert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.' %} + if value == 0: + # If value == 0, add an unsatisfiable requirement. + value = 1 + end + + return () +end + +# Verifies that a != b. The proof will fail otherwise. +func assert_not_equal(a, b): + %{ assert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.' %} + if a == b: + # If a == b, add an unsatisfiable requirement. + [fp - 1] = [fp - 1] + 1 + end + + return () +end + +# Verifies that a >= 0 (or more precisely 0 <= a < RANGE_CHECK_BOUND). +func assert_nn{range_check_ptr}(a): + %{ assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.' %} + a = [range_check_ptr] + let range_check_ptr = range_check_ptr + 1 + return () +end + +# Verifies that a <= b (or more precisely 0 <= b - a < RANGE_CHECK_BOUND). +func assert_le{range_check_ptr}(a, b): + assert_nn(b - a) + return () +end + +# Verifies that a <= b - 1 (or more precisely 0 <= b - 1 - a < RANGE_CHECK_BOUND). +func assert_lt{range_check_ptr}(a, b): + assert_le(a, b - 1) + return () +end + +# Verifies that 0 <= a <= b. +# +# Prover assumption: a, b < RANGE_CHECK_BOUND. +func assert_nn_le{range_check_ptr}(a, b): + assert_nn(a) + assert_le(a, b) + return () +end + +# Asserts that value is in the range [lower, upper). +func assert_in_range{range_check_ptr}(value, lower, upper): + assert_le(lower, value) + assert_le(value, upper - 1) + return () +end + +# Asserts that a <= b. More specifically, asserts that b - a is in the range [0, 2**250). +# +# Prover assumptions: +# PRIME - 2**250 > 2**(250 - 128) + 1 * RC_BOUND. +func assert_le_250_bit{range_check_ptr}(a, b): + let low = [range_check_ptr] + let high = [range_check_ptr + 1] + let range_check_ptr = range_check_ptr + 2 + const UPPER_BOUND = %[2**(250)%] + const HIGH_PART_SHIFT = %[2**250 // 2**128 %] + tempvar diff = b - a + %{ + def as_int(val): + return val if val < PRIME // 2 else val - PRIME + # Soundness checks. + assert range_check_builtin.bound == 2**128 + assert ids.UPPER_BOUND == ids.HIGH_PART_SHIFT * range_check_builtin.bound + + # Correctness check. 
+ diff = as_int(ids.diff) + values_msg = f'(a={as_int(ids.a)}, b={as_int(ids.b)}).' + assert diff < ids.UPPER_BOUND, f'(b - a)={diff} is outside of the valid range. {values_msg}' + assert PRIME - ids.UPPER_BOUND > (ids.HIGH_PART_SHIFT + 1) * range_check_builtin.bound + + assert diff >= 0, f'(b - a)={diff} < 0. {values_msg}' + + # Calculation for the assertion. + ids.high = ids.diff // ids.HIGH_PART_SHIFT + ids.low = ids.diff % ids.HIGH_PART_SHIFT + %} + + # Assuming the assert below, we have + # diff = high * HIGH_PART_SHIFT + low < (HIGH_PART_SHIFT + 1) * RC_BOUND < PRIME - UPPER_BOUND. + # If 0 <= b < a < UPPER_BOUND then diff < 0 => diff % P = PRIME - diff > PRIME - UPPER_BOUND. + # So given the soundness assumptions listed above it must be the case that a <= b. + assert diff = high * HIGH_PART_SHIFT + low + + return () +end + +# Splits the unsigned integer lift of a field element into the higher 128 bit and lower 128 bit. +# The unsigned integer lift is the unique integer in the range [0, PRIME) that represents the field +# element. +# For example, if value=17 * 2^128 + 8, then high=17 and low=8. +func split_felt{range_check_ptr}(value) -> (high, low): + const MAX_HIGH = %[(PRIME - 1) >> 128%] + const MAX_LOW = %[(PRIME - 1) & ((1 << 128) - 1)%] + + # Guess the low and high parts of the integer. + let low = [range_check_ptr] + let high = [range_check_ptr + 1] + let range_check_ptr = range_check_ptr + 2 + + %{ + assert PRIME < 2**256 + ids.low = ids.value & ((1 << 128) - 1) + ids.high = ids.value >> 128 + %} + assert value = high * %[2**128%] + low + if high == MAX_HIGH: + assert_le(low, MAX_LOW) + else: + assert_le(high, MAX_HIGH) + end + return (high=high, low=low) +end + +# Asserts that the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +# or equal to that of b. +# See split_felt() for more details. +func assert_le_felt{range_check_ptr}(a, b): + %{ + assert (ids.a % PRIME) <= (ids.b % PRIME), \ + f'a = {ids.a % PRIME} is not less than or equal to b = {ids.b % PRIME}.' + %} + alloc_locals + let (local a_high, local a_low) = split_felt(a) + let (b_high, b_low) = split_felt(b) + + if a_high == b_high: + assert_le(a_low, b_low) + return () + end + assert_le(a_high, b_high) + return () +end + +# Asserts that the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +# that of b. +func assert_lt_felt{range_check_ptr}(a, b): + %{ + assert (ids.a % PRIME) < (ids.b % PRIME), \ + f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.' + %} + alloc_locals + let (local a_high, local a_low) = split_felt(a) + let (b_high, b_low) = split_felt(b) + + if a_high == b_high: + assert_lt(a_low, b_low) + return () + end + assert_lt(a_high, b_high) + return () +end + +# Returns the absolute value of value. +# Prover asumption: -rc_bound < value < rc_bound. +func abs_value{range_check_ptr}(value) -> (abs_value): + %{ + from starkware.cairo.common.math_utils import is_positive + memory[ap] = 1 if is_positive( + value=ids.value, prime=PRIME, rc_bound=range_check_builtin.bound) else 0 + %} + jmp is_positive if [ap] != 0; ap++ + tempvar new_range_check_ptr = range_check_ptr + 1 + tempvar abs_value = value * (-1) + [range_check_ptr] = abs_value + let range_check_ptr = new_range_check_ptr + return (abs_value=abs_value) + + is_positive: + [range_check_ptr] = value + let range_check_ptr = range_check_ptr + 1 + return (abs_value=value) +end + +# Returns the sign of value: -1, 0 or 1. +# Prover asumption: -rc_bound < value < rc_bound. 
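+# Examples (illustrative):
+#   sign(0) returns 0, sign(10) returns 1, and sign(-10) returns -1
+#   (where -10 is represented as the field element PRIME - 10).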
+func sign{range_check_ptr}(value) -> (sign): + if value == 0: + return (sign=0) + end + + %{ + from starkware.cairo.common.math_utils import is_positive + memory[ap] = 1 if is_positive( + value=ids.value, prime=PRIME, rc_bound=range_check_builtin.bound) else 0 + %} + jmp is_positive if [ap] != 0; ap++ + assert [range_check_ptr] = value * (-1) + let range_check_ptr = range_check_ptr + 1 + return (sign=-1) + + is_positive: + [range_check_ptr] = value + let range_check_ptr = range_check_ptr + 1 + return (sign=1) +end + +# Returns q and r such that: +# 0 <= q < rc_bound, 0 <= r < div and value = q * div + r. +# +# Assumption: 0 < div <= PRIME / rc_bound. +# Prover assumption: value / div < rc_bound. +# +# The value of div is restricted to make sure there is no overflow. +# q * div + r < (q + 1) * div <= rc_bound * (PRIME / rc_bound) = PRIME. +func unsigned_div_rem{range_check_ptr}(value, div) -> (q, r): + let r = [range_check_ptr] + let q = [range_check_ptr + 1] + let range_check_ptr = range_check_ptr + 2 + %{ + assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ + f'div={hex(ids.div)} is out of the valid range.' + ids.q, ids.r = divmod(ids.value, ids.div) + %} + assert_le(r, div - 1) + + assert value = q * div + r + return (q, r) +end + +# Returns q and r such that. -bound <= q < bound, 0 <= r < div -1 and value = q * div + r. +# value < PRIME / 2 is considered positive and value > PRIME / 2 is considered negative. +# +# Assumptions: +# 0 < div <= PRIME / (rc_bound) +# bound <= rc_bound / 2. +# Prover assumption: -bound <= value / div < bound. +# +# The values of div and bound are restricted to make sure there is no overflow. +# q * div + r < (q + 1) * div <= rc_bound / 2 * (PRIME / rc_bound) +# q * div + r >= q * div >= -rc_bound / 2 * (PRIME / rc_bound) +func signed_div_rem{range_check_ptr}(value, div, bound) -> (q, r): + let r = [range_check_ptr] + let biased_q = [range_check_ptr + 1] # == q + bound. + let range_check_ptr = range_check_ptr + 2 + %{ + def as_int(val): + return val if val < PRIME // 2 else val - PRIME + + assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ + f'div={hex(ids.div)} is out of the valid range.' + + assert ids.bound <= range_check_builtin.bound // 2, \ + f'bound={hex(ids.bound)} is out of the valid range.' + + int_value = as_int(ids.value) + q, ids.r = divmod(int_value, ids.div) + + assert -ids.bound <= q < ids.bound, \ + f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).' + + ids.biased_q = q + ids.bound + %} + let q = biased_q - bound + assert value = q * div + r + assert_le(r, div - 1) + assert_le(biased_q, 2 * bound - 1) + return (q, r) +end diff --git a/src/starkware/cairo/common/math_cmp.cairo b/src/starkware/cairo/common/math_cmp.cairo new file mode 100644 index 0000000..a6dae69 --- /dev/null +++ b/src/starkware/cairo/common/math_cmp.cairo @@ -0,0 +1,76 @@ +from starkware.cairo.common.math import assert_le_felt, assert_lt_felt + +const RC_BOUND = %[ 2**128 %] + +# Returns 1 if value != 0. Returns 0 otherwise. +func is_not_zero(value) -> (res): + if value == 0: + return (res=0) + end + + return (res=1) +end + +# Returns 1 if a >= 0 (or more precisely 0 <= a < RANGE_CHECK_BOUND). +# Returns 0 otherwise. 
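+# Examples (illustrative):
+#   is_nn(1) returns 1.
+#   is_nn(-1) returns 0, since -1 is represented as PRIME - 1, which is >= RANGE_CHECK_BOUND.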
+func is_nn{range_check_ptr}(a) -> (res): + %{ memory[ap] = 0 if 0 <= (ids.a % PRIME) < range_check_builtin.bound else 1 %} + jmp out_of_range if [ap] != 0; ap++ + [range_check_ptr] = a + let range_check_ptr = range_check_ptr + 1 + return (res=1) + + out_of_range: + %{ memory[ap] = 0 if 0 <= ((-ids.a - 1) % PRIME) < range_check_builtin.bound else 1 %} + jmp need_felt_comparison if [ap] != 0; ap++ + assert [range_check_ptr] = (-a) - 1 + let range_check_ptr = range_check_ptr + 1 + return (res=0) + + need_felt_comparison: + assert_le_felt(RC_BOUND, a) + return (res=0) +end + +# Returns 1 if a <= b (or more precisely 0 <= b - a < RANGE_CHECK_BOUND). +# Returns 0 otherwise. +func is_le{range_check_ptr}(a, b) -> (res): + return is_nn(b - a) +end + +# Returns 1 of 0 <= a <= b < RANGE_CHECK_BOUND. +# Returns 0 otherwise. +func is_nn_le{range_check_ptr}(a, b) -> (res): + let (res) = is_nn(a) + if res == 0: + return (res=res) + end + return is_le(a, b) +end + +# Returns 1 if value is in the range [lower, upper). +# Returns 0 otherwise. +# Assumptions: +# upper - lower <= RC_BOUND +func is_in_range{range_check_ptr}(value, lower, upper) -> (res): + let (res) = is_le(lower, value) + if res == 0: + return (res=res) + end + return is_le(value, upper - 1) +end + +# Checks if the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +# or equal to that of b. +# See split_felt() for more details. +# Returns 1 if true, 0 otherwise. +func is_le_felt{range_check_ptr}(a, b) -> (res): + %{ memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1 %} + jmp not_le if [ap] != 0; ap++ + assert_le_felt(a, b) + return (res=1) + + not_le: + assert_lt_felt(b, a) + return (res=0) +end diff --git a/src/starkware/cairo/common/math_utils.py b/src/starkware/cairo/common/math_utils.py new file mode 100644 index 0000000..449c1f4 --- /dev/null +++ b/src/starkware/cairo/common/math_utils.py @@ -0,0 +1,17 @@ +def as_int(val, prime): + """ + Returns the lift of the given field element, val, as an integer in the range + (-prime/2, prime/2). + """ + return val if val < prime // 2 else val - prime + + +def is_positive(value, prime, rc_bound): + """ + Returns True if the lift of the given field element, as an integer in the range + (-rc_bound, rc_bound), is positive. + Raises an exception if the element is not within that range. + """ + val = as_int(value, prime) + assert abs(val) < rc_bound, f'value={val} is out of the valid range.' + return val > 0 diff --git a/src/starkware/cairo/common/memcpy.cairo b/src/starkware/cairo/common/memcpy.cairo new file mode 100644 index 0000000..09487ee --- /dev/null +++ b/src/starkware/cairo/common/memcpy.cairo @@ -0,0 +1,37 @@ +# Copies len field elements from src to dst. +func memcpy(dst : felt*, src : felt*, len): + struct LoopFrame: + member dst : felt* + member src : felt* + end + + if len == 0: + return () + end + + let frame = cast(ap, LoopFrame*) + %{ vm_enter_scope({'n': ids.len}) %} + frame.dst = dst; ap++ + frame.src = src; ap++ + + loop: + let frame = cast(ap - LoopFrame.SIZE, LoopFrame*) + assert [frame.dst] = [frame.src] + + let continue_copying = [ap] + # Reserve space for continue_copying. + let next_frame = cast(ap + 1, LoopFrame*) + next_frame.dst = frame.dst + 1; ap++ + next_frame.src = frame.src + 1; ap++ + %{ + n -= 1 + ids.continue_copying = 1 if n > 0 else 0 + %} + static_assert next_frame + LoopFrame.SIZE == ap + 1 + jmp loop if continue_copying != 0; ap++ + # Assert that the loop executed len times. 
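+    # (next_frame.src advances by one cell per iteration, starting from src, so the difference
+    # below equals the number of cells that were copied.)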
+ len = cast(next_frame.src, felt) - cast(src, felt) + + %{ vm_exit_scope() %} + return () +end diff --git a/src/starkware/cairo/common/merkle_multi_update.cairo b/src/starkware/cairo/common/merkle_multi_update.cairo new file mode 100644 index 0000000..9e0e00f --- /dev/null +++ b/src/starkware/cairo/common/merkle_multi_update.cairo @@ -0,0 +1,168 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.dict_access import DictAccess + +# Performs an efficient update of multiple leaves in a Merkle tree. +# +# Arguments: +# update_ptr - a list of DictAccess instances sorted by key (e.g., the result of squash_dict). +# height - the height of the merkle tree. +# prev_root - the value of the root before the update. +# new_root - the value of the root after the update. +# +# Hint arguments: +# preimage - a dictionary from the hash value of a merkle node to the pair of children values. +# +# Implicit arguments: +# hash_ptr - hash builtin pointer. +# +# Assumptions: The keys in the update_ptr list are unique and sorted. +# Guarantees: All the keys in the update_ptr list are < 2**height. +# +# Pseudocode: +# def diff(prev, new, height): +# if height == 0: return [(prev,new)] +# if prev.left==new.left: return diff(prev.right, new.right, height - 1) +# if prev.right==new.right: return diff(prev.left, new.left, height - 1) +# return diff(prev.left, new.left, height - 1) + \ +# diff(prev.right, new.right, height - 1) +func merkle_multi_update{hash_ptr : HashBuiltin*}( + update_ptr : DictAccess*, n_updates, height, prev_root, new_root): + if n_updates == 0: + prev_root = new_root + return () + end + + %{ + from starkware.python.merkle_tree import build_update_tree + + # Build modifications list. + modifications = [] + for i in range(ids.n_updates): + curr_update_ptr = ids.update_ptr.address_ + i * ids.DictAccess.SIZE + modifications.append(( + memory[curr_update_ptr + ids.DictAccess.key], + memory[curr_update_ptr + ids.DictAccess.new_value])) + + node = build_update_tree(ids.height, modifications) + del modifications + vm_enter_scope(dict(node=node, preimage=preimage)) + %} + let orig_update_ptr = update_ptr + with update_ptr: + merkle_multi_update_inner(height=height, prev_root=prev_root, new_root=new_root, index=0) + end + assert update_ptr = orig_update_ptr + n_updates * DictAccess.SIZE + %{ vm_exit_scope() %} + return () +end + +# Helper function for merkle_multi_update(). +func merkle_multi_update_inner{hash_ptr : HashBuiltin*, update_ptr : DictAccess*}( + height, prev_root, new_root, index): + let hash0 : HashBuiltin* = hash_ptr + let hash1 : HashBuiltin* = hash_ptr + HashBuiltin.SIZE + %{ + if ids.height == 0: + assert node == ids.new_root, f'Expected node {ids.new_root}. Got {node}.' + case = 'leaf' + else: + prev_left, prev_right = preimage[ids.prev_root] + new_left, new_right = preimage[ids.new_root] + + left_child, right_child = node + if left_child is None: + assert right_child is not None, 'No updates in tree' + case = 'right' + elif right_child is None: + case = 'left' + else: + case = 'both' + + # Fill non deterministic hashes. 
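+            # The first hash instance rebuilds prev_root from (prev_left, prev_right) and the
+            # second rebuilds new_root from (new_left, new_right).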
+ hash_ptr = ids.hash_ptr.address_ + memory[hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = prev_left + memory[hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = prev_right + memory[hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = new_left + memory[hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = new_right + + memory[ap] = int(case != 'right') + %} + jmp not_right if [ap] != 0; ap++ + + update_right: + let hash_ptr = hash_ptr + 2 * HashBuiltin.SIZE + prev_root = hash0.result + new_root = hash1.result + + # Make sure the same authentication path is used. + assert hash0.x = hash1.x + + # Call merkle_multi_update_inner recursively. + %{ vm_enter_scope(dict(node=right_child, preimage=preimage)) %} + + merkle_multi_update_inner( + height=height - 1, prev_root=hash0.y, new_root=hash1.y, index=index * 2 + 1) + %{ vm_exit_scope() %} + return () + + not_right: + %{ memory[ap] = int(case != 'left') %} + jmp not_left if [ap] != 0; ap++ + + update_left: + let hash_ptr = hash_ptr + 2 * HashBuiltin.SIZE + prev_root = hash0.result + new_root = hash1.result + + # Make sure the same authentication path is used. + assert hash0.y = hash1.y + + # Call merkle_multi_update_inner recursively. + %{ vm_enter_scope(dict(node=left_child, preimage=preimage)) %} + merkle_multi_update_inner( + height=height - 1, prev_root=hash0.x, new_root=hash1.x, index=index * 2) + %{ vm_exit_scope() %} + return () + + not_left: + jmp update_both if height != 0 + + update_leaf: + # Note: height may underflow, but in order to reach 0 (which is verified here), we will need + # more steps than the field characteristic. The assumption is that it is not feasible. + + # Write the update. + %{ assert case == 'leaf' %} + index = update_ptr.key + prev_root = update_ptr.prev_value + new_root = update_ptr.new_value + let update_ptr = update_ptr + DictAccess.SIZE + + # Return values. + return () + + update_both: + # Locals 0 and 1 are taken by non deterministic jumps. + let local_left_index = [fp + 2] + %{ assert case == 'both' %} + local_left_index = index * 2; ap++ + + let hash_ptr = hash_ptr + 2 * HashBuiltin.SIZE + prev_root = hash0.result + new_root = hash1.result + + # Update left. + %{ vm_enter_scope(dict(node=left_child, preimage=preimage)) %} + merkle_multi_update_inner( + height=height - 1, prev_root=hash0.x, new_root=hash1.x, index=index * 2) + %{ vm_exit_scope() %} + + # Update right. + # Push height to workaround one hint per line limitation. + tempvar height_minus_1 = height - 1 + %{ vm_enter_scope(dict(node=right_child, preimage=preimage)) %} + merkle_multi_update_inner( + height=height_minus_1, prev_root=hash0.y, new_root=hash1.y, index=local_left_index + 1) + %{ vm_exit_scope() %} + return () +end diff --git a/src/starkware/cairo/common/merkle_update.cairo b/src/starkware/cairo/common/merkle_update.cairo new file mode 100644 index 0000000..a39af85 --- /dev/null +++ b/src/starkware/cairo/common/merkle_update.cairo @@ -0,0 +1,71 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin + +# Performs an update for a single leaf (index) in a Merkle tree (where 0 <= index < 2^height). +# Updates the leaf from prev_leaf to new_leaf, and returns the previous and new roots of the +# Merkle tree resulting from the change. 
+# In particular, given a secret authentication path (of the siblings of the nodes in the path from +# the root to the leaf), this function computes the roots twice - once with prev_leaf and once with +# new_leaf, where the verifier is guaranteed that the same authentication path is used. +func merkle_update{hash_ptr : HashBuiltin*}(height, prev_leaf, new_leaf, index) -> ( + prev_root, new_root): + if height == 0: + # Assert that index is 0. + index = 0 + # Return the two leaves and the Pedersen pointer. + %{ + # Check that auth_path had the right number of elements. + assert len(auth_path) == 0, 'Got too many values in auth_path.' + %} + return (prev_root=prev_leaf, new_root=new_leaf) + end + + let prev_node_hash = hash_ptr + let new_node_hash = hash_ptr + HashBuiltin.SIZE + let hash_ptr = hash_ptr + 2 * HashBuiltin.SIZE + + %{ memory[ap] = ids.index % 2 %} + jmp update_right if [ap] != 0; ap++ + + update_left: + %{ + # Hash hints. + sibling = auth_path.pop() + ids.prev_node_hash.y = sibling + ids.new_node_hash.y = sibling + %} + prev_leaf = prev_node_hash.x + new_leaf = new_node_hash.x + + # Make sure the same authentication path is used. + let right_sibling = ap + [right_sibling] = prev_node_hash.y + [right_sibling] = new_node_hash.y; ap++ + + # Call merkle_update recursively. + return merkle_update( + height=height - 1, + prev_leaf=prev_node_hash.result, + new_leaf=new_node_hash.result, + index=index / 2) + + update_right: + %{ + # Hash hints. + sibling = auth_path.pop() + ids.prev_node_hash.x = sibling + ids.new_node_hash.x = sibling + %} + prev_leaf = prev_node_hash.y + new_leaf = new_node_hash.y + + # Make sure the same authentication path is used. + let left_sibling = ap + [left_sibling] = prev_node_hash.x + [left_sibling] = new_node_hash.x; ap++ + + return merkle_update( + height=height - 1, + prev_leaf=prev_node_hash.result, + new_leaf=new_node_hash.result, + index=(index - 1) / 2) +end diff --git a/src/starkware/cairo/common/registers.cairo b/src/starkware/cairo/common/registers.cairo new file mode 100644 index 0000000..be2758a --- /dev/null +++ b/src/starkware/cairo/common/registers.cairo @@ -0,0 +1,47 @@ +# Returns the contents of the fp and pc registers of the calling function. +# The pc register's value is the address of the instruction that follows directly after the +# invocation of get_fp_and_pc(). +func get_fp_and_pc() -> (fp_val, pc_val): + # The call instruction itself already places the old fp and the return pc at [ap - 2], [ap - 1]. + return (fp_val=[ap - 2], pc_val=[ap - 1]) +end + +# Returns the content of the ap register just before this function was invoked. +func get_ap() -> (ap_val): + # Once get_ap() is invoked, fp points to ap + 2 (since the call instruction placed the old fp + # and pc in memory, advancing ap accordingly). + # Calling dummy_func places fp and pc at [fp], [fp + 1] (respectively), and advances ap by 2. + # Hence, going two cells above we get [fp] = ap + 2, and by subtracting 2 we get the desired ap + # value. + call dummy_func + return (ap_val=[ap - 2] - 2) +end + +func dummy_func(): + return () +end + +# Takes the value of a label (relative to program base) and returns the actual runtime address of +# that label in the memory. +# +# Example usage: +# +# func do_callback(...): +# ... +# end +# +# func do_thing_then_callback(callback): +# ... 
+# call abs callback +# end +# +# func main(): +# let (callback_address) = get_label_location(do_callback) +# do_thing_then_callback(callback=callback_address) +# end +func get_label_location(label_value) -> (res): + let (_, pc_val) = get_fp_and_pc() + + ret_pc_label: + return (res=label_value + pc_val - ret_pc_label) +end diff --git a/src/starkware/cairo/common/serialize.cairo b/src/starkware/cairo/common/serialize.cairo new file mode 100644 index 0000000..99ecec4 --- /dev/null +++ b/src/starkware/cairo/common/serialize.cairo @@ -0,0 +1,51 @@ +# Appends a single word to the output pointer, and returns the pointer to the next output cell. +func serialize_word{output_ptr : felt*}(word): + assert [output_ptr] = word + let output_ptr = output_ptr + 1 + return () +end + +# Array right fold: computes the following: +# callback(callback(... callback(value, a[n-1]) ..., a[1]), a[0]) +# Arguments: +# value - the initial value. +# array - a pointer to an array. +# elm_size - the size of an element in the array. +# n_elms - the number of elements in the array. +# callback - a function pointer to the callback. Expected signature: (felt, T*) -> felt. +func array_rfold(value, array : felt*, n_elms, elm_size, callback) -> (res): + if n_elms == 0: + return (value) + end + + [ap] = value; ap++ + [ap] = array; ap++ + call abs callback + # [ap - 1] holds the return value of callback. + return array_rfold( + value=[ap - 1], + array=array + elm_size, + n_elms=n_elms - 1, + elm_size=elm_size, + callback=callback) +end + +# Serializes an array of objects to output_ptr, and returns the pointer to the next output cell. +# The format is: len(array) || callback(a[0]) || ... || callback(a[n-1]) . +# Arguments: +# output_ptr - the pointer to serialize to. +# array - a pointer to an array. +# elm_size - the size of an element in the array. +# n_elms - the number of elements in the array. +# callback - a function pointer to the serialize function of a single element. +# Expected signature: (felt, T*) -> felt. +func serialize_array{output_ptr : felt*}(array : felt*, n_elms, elm_size, callback): + serialize_word(n_elms) + let (output_ptr : felt*) = array_rfold( + value=cast(output_ptr, felt), + array=array, + n_elms=n_elms, + elm_size=elm_size, + callback=callback) + return () +end diff --git a/src/starkware/cairo/common/signature.cairo b/src/starkware/cairo/common/signature.cairo new file mode 100644 index 0000000..3fb1aa6 --- /dev/null +++ b/src/starkware/cairo/common/signature.cairo @@ -0,0 +1,15 @@ +from starkware.cairo.common.cairo_builtins import SignatureBuiltin + +# Verifies that the prover knows a signature of the given public_key on the given message. +# +# Prover assumption: (signature_r, signature_s) is a valid signature for the given public_key +# on the given message. 
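+# A minimal usage sketch (assuming an ecdsa_ptr implicit argument is in scope, and that msg_hash,
+# pub_key, sig_r and sig_s are felts holding the signed message, the public key and the
+# signature):
+#   verify_ecdsa_signature(
+#       message=msg_hash, public_key=pub_key, signature_r=sig_r, signature_s=sig_s)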
+func verify_ecdsa_signature{ecdsa_ptr : SignatureBuiltin*}( + message, public_key, signature_r, signature_s): + %{ ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s)) %} + assert ecdsa_ptr.message = message + assert ecdsa_ptr.pub_key = public_key + + let ecdsa_ptr = ecdsa_ptr + SignatureBuiltin.SIZE + return () +end diff --git a/src/starkware/cairo/common/small_merkle_tree.cairo b/src/starkware/cairo/common/small_merkle_tree.cairo new file mode 100644 index 0000000..aadfd8d --- /dev/null +++ b/src/starkware/cairo/common/small_merkle_tree.cairo @@ -0,0 +1,101 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.dict import DictAccess +from starkware.cairo.common.merkle_multi_update import merkle_multi_update + +# Performs an efficient update of multiple leaves in a Merkle tree, based on the given squashed +# dict, assuming the merkle tree is small enough to be loaded to the memory. +# +# This function computes the Merkle authentication paths internally and +# does not require any hint arguments, therefore it's usually easier to use. +# The input dict must be created using the higher-level dict functions (see dict.cairo), which add +# information about all the non-default leaves in the hints (not just the leaves that were changed). +# +# Usage example: +# %{ initial_dict = {1: 2, 3: 4, 5: 6} %} +# let (dict_ptr_start) = dict_new() +# let dict_ptr = dict_ptr_start +# let (dict_ptr) = dict_update(dict_ptr=dict_ptr, key=1, prev_value=2, new_value=20) +# let (range_check_ptr, squashed_dict_start, squashed_dict_end) = dict_squash( +# range_check_ptr=range_check_ptr, +# dict_accesses_start=dict_ptr_start, +# dict_accesses_end=dict_ptr) +# const HEIGHT = 3 +# let (prev_root, new_root) = small_merkle_tree( +# squashed_dict_start, squashed_dict_end, HEIGHT) +# +# In this example prev_root is the Merkle root of [0, 2, 0, 4, 0, 6, 0, 0], and new_root +# is the Merkle root of [0, 20, 0, 4, 0, 6, 0, 0]. +# Note that from the point of view of the verifier, all it knows is that leaf 1 changed from 2 to +# 20 -- it doesn't know anything about the other leaves (except that they haven't changed). +# +# Arguments: +# squashed_dict, squashed_dict_end - a list of DictAccess instances sorted by key +# (e.g., the result of dict_squash). +# height - the height of the merkle tree. +# +# Implicit arguments: +# hash_ptr - hash builtin pointer. +# +# Returns: +# prev_root - the value of the root before the update. +# new_root - the value of the root after the update. +# +# Assumptions: The keys in the squashed_dict are unique and sorted. +# +# Prover assumptions: +# * squashed_dict was created using the higher-level API dict_squash() (rather than squash_dict()). +# * This function can be used for (relatively) small Merkle trees whose leaves can be loaded +# to the memory. +func small_merkle_tree{hash_ptr : HashBuiltin*}( + squashed_dict_start : DictAccess*, squashed_dict_end : DictAccess*, height : felt) -> ( + prev_root : felt, new_root : felt): + %{ vm_enter_scope({'__dict_manager': __dict_manager}) %} + alloc_locals + # Allocate memory cells for the roots. + local prev_root + local new_root + %{ + # Compute the roots and the preimage dictionary. 
+ from starkware.cairo.common.small_merkle_tree import get_preimage_dictionary + from starkware.python.math_utils import safe_div + + new_dict = __dict_manager.get_dict(ids.squashed_dict_end.address_) + + DICT_ACCESS_SIZE = ids.DictAccess.SIZE + squashed_dict_start = ids.squashed_dict_start.address_ + squashed_dict_size = ids.squashed_dict_end.address_ - squashed_dict_start + assert squashed_dict_size >= 0 and squashed_dict_size % DICT_ACCESS_SIZE == 0, \ + f'squashed_dict size must be non-negative and divisible by DictAccess.SIZE. ' \ + f'Found: {squashed_dict_size}.' + squashed_dict_length = safe_div(squashed_dict_size, DICT_ACCESS_SIZE) + + # Compute the modifications backwards: from the new values to the previous values. + modifications = [] + for i in range(squashed_dict_length): + key = memory[squashed_dict_start + i * DICT_ACCESS_SIZE + ids.DictAccess.key] + prev_value = memory[ + squashed_dict_start + i * DICT_ACCESS_SIZE + ids.DictAccess.prev_value] + new_value = memory[ + squashed_dict_start + i * DICT_ACCESS_SIZE + ids.DictAccess.new_value] + assert new_dict[key] == new_value, \ + f'Inconsistent dictionary values. Expected new value: {new_dict[key]}, ' \ + f'found: {new_value}' + modifications.append((key, prev_value)) + + ids.new_root, ids.prev_root, preimage = get_preimage_dictionary( + initial_leaves=new_dict.items(), + modifications=modifications, + tree_height=ids.height, + default_leaf=0) + %} + + # Call merkle_multi_update() to verify the two roots. + merkle_multi_update( + update_ptr=squashed_dict_start, + n_updates=(squashed_dict_end - squashed_dict_start) / DictAccess.SIZE, + height=height, + prev_root=prev_root, + new_root=new_root) + %{ vm_exit_scope() %} + return (prev_root=prev_root, new_root=new_root) +end diff --git a/src/starkware/cairo/common/small_merkle_tree.py b/src/starkware/cairo/common/small_merkle_tree.py new file mode 100644 index 0000000..99db510 --- /dev/null +++ b/src/starkware/cairo/common/small_merkle_tree.py @@ -0,0 +1,69 @@ +from typing import Collection, Dict, Tuple + +from starkware.cairo.lang.vm.crypto import pedersen_hash + + +class MerkleTree: + def __init__(self, tree_height: int, default_leaf: int): + self.tree_height = tree_height + self.default_leaf = default_leaf + + # Compute the root of an empty tree. + empty_tree_root = default_leaf + for _ in range(tree_height): + empty_tree_root = pedersen_hash(empty_tree_root, empty_tree_root) + + # A map from node indices to their values. + self.node_values: Dict[int, int] = {1: empty_tree_root} + # A map from node hash to its two children. + self.preimage: Dict[int, Tuple[int, int]] = {} + + def compute_merkle_root(self, modifications: Collection[Tuple[int, int]]): + """ + Applies the given modifications (a list of (leaf index, value)) to the tree and returns + the Merkle root. 
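+        For example (illustrative), for tree_height=1 and default_leaf=0, applying
+        modifications=[(1, 5)] returns pedersen_hash(0, 5).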
+ """ + + if len(modifications) == 0: + return self.node_values[1] + + default_node = self.default_leaf + indices = set() + leaves_offset = 2 ** self.tree_height + for index, value in modifications: + node_index = leaves_offset + index + self.node_values[node_index] = value + indices.add(node_index // 2) + for _ in range(self.tree_height): + new_indices = set() + while len(indices) > 0: + index = indices.pop() + left = self.node_values.get(2 * index, default_node) + right = self.node_values.get(2 * index + 1, default_node) + self.node_values[index] = node_hash = pedersen_hash(left, right) + self.preimage[node_hash] = (left, right) + new_indices.add(index // 2) + default_node = pedersen_hash(default_node, default_node) + indices = new_indices + assert indices == {0} + return self.node_values[1] + + +def get_preimage_dictionary( + initial_leaves: Collection[Tuple[int, int]], modifications: Collection[Tuple[int, int]], + tree_height: int, default_leaf: int) -> Tuple[int, int, Dict[int, Tuple[int, int]]]: + """ + Given a set of initial leaves and a set of modifications + (both are maps from leaf index to value, where all the leaves in `modifications` appear + in `initial_leaves`). + Constructs two merkle trees, before and after the modifications. + Returns (root_before, root_after, preimage) where preimage is a dictionary from a node to + its two children. + """ + + merkle_tree = MerkleTree(tree_height=tree_height, default_leaf=default_leaf) + + root_before = merkle_tree.compute_merkle_root(modifications=initial_leaves) + root_after = merkle_tree.compute_merkle_root(modifications=modifications) + + return root_before, root_after, merkle_tree.preimage diff --git a/src/starkware/cairo/common/small_merkle_tree_test.py b/src/starkware/cairo/common/small_merkle_tree_test.py new file mode 100644 index 0000000..ac79dae --- /dev/null +++ b/src/starkware/cairo/common/small_merkle_tree_test.py @@ -0,0 +1,50 @@ +import os + +from starkware.cairo.common.dict import DictManager +from starkware.cairo.common.small_merkle_tree import MerkleTree +from starkware.cairo.common.test_utils import CairoFunctionRunner +from starkware.cairo.lang.builtins.hash.hash_builtin_runner import CELLS_PER_HASH +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.native_crypto.native_crypto import pedersen_hash + +CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'small_merkle_tree.cairo') +PRIME = 2**251 + 17 * 2**192 + 1 +MERKLE_HEIGHT = 2 + + +def test_cairo_merkle_multi_update(): + program = compile_cairo_files([CAIRO_FILE], prime=PRIME, debug_info=True) + runner = CairoFunctionRunner(program) + + dict_manager = DictManager() + squashed_dict_start = dict_manager.new_dict( + segments=runner.segments, initial_dict={1: 10, 2: 20, 3: 30}) + + # Change the value at 1 from 10 to 11 and at 3 from 30 to 31. 
+ squashed_dict = [1, 10, 11, 3, 30, 31] + squashed_dict_end = runner.segments.write_arg(ptr=squashed_dict_start, arg=squashed_dict) + dict_tracker = dict_manager.get_tracker(squashed_dict_start) + dict_tracker.current_ptr = squashed_dict_end + dict_tracker.data[1] = 11 + dict_tracker.data[3] = 31 + + runner.run( + 'small_merkle_tree', runner.hash_builtin.base, squashed_dict_start, squashed_dict_end, + MERKLE_HEIGHT, hint_locals=dict(__dict_manager=dict_manager)) + hash_ptr, prev_root, new_root = runner.get_return_values(3) + N_MERKLE_TREES = 2 + N_HASHES_PER_TREE = 3 + assert hash_ptr == \ + runner.hash_builtin.base + N_MERKLE_TREES * N_HASHES_PER_TREE * CELLS_PER_HASH + assert prev_root == pedersen_hash(pedersen_hash(0, 10), pedersen_hash(20, 30)) + assert new_root == pedersen_hash(pedersen_hash(0, 11), pedersen_hash(20, 31)) + + +def test_merkle_tree(): + tree = MerkleTree(tree_height=2, default_leaf=10) + expected_hash = pedersen_hash(pedersen_hash(10, 10), pedersen_hash(10, 10)) + assert tree.compute_merkle_root([]) == expected_hash + # Change leaf 1 to 7. + expected_hash = pedersen_hash(pedersen_hash(10, 7), pedersen_hash(10, 10)) + assert tree.compute_merkle_root([(1, 7)]) == expected_hash + assert tree.compute_merkle_root([]) == expected_hash diff --git a/src/starkware/cairo/common/squash_dict.cairo b/src/starkware/cairo/common/squash_dict.cairo new file mode 100644 index 0000000..7891674 --- /dev/null +++ b/src/starkware/cairo/common/squash_dict.cairo @@ -0,0 +1,231 @@ +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math import assert_lt_felt + +# Verifies that dict_accesses lists valid chronological accesses (and updates) +# to a mutable dictionary and outputs a squashed dict with one DictAccess instance per key +# (value before and value after) which summarizes all the changes to that key. +# +# Example: +# Input: {(key1, 0, 2), (key1, 2, 7), (key2, 4, 1), (key1, 7, 5), (key2, 1, 2)} +# Output: {(key1, 0, 5), (key2, 4, 2)} +# +# Arguments: +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. The format of each +# entry is a triplet (key, prev_value, new_value). +# dict_accesses_end - a pointer to the end of said array. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Returns: +# squashed_dict - end pointer to squashed_dict. +# +# Implicit arguments: +# range_check_ptr - range check builtin pointer. +func squash_dict{range_check_ptr}( + dict_accesses : DictAccess*, dict_accesses_end : DictAccess*, + squashed_dict : DictAccess*) -> (squashed_dict : DictAccess*): + let ptr_diff = [ap] + %{ vm_enter_scope() %} + ptr_diff = dict_accesses_end - dict_accesses; ap++ + + if ptr_diff == 0: + # Access array is empty, nothing to check. + %{ vm_exit_scope() %} + return (squashed_dict=squashed_dict) + end + let first_key = [fp + 1] + let big_keys = [fp + 2] + ap += 2 + tempvar n_accesses = ptr_diff / DictAccess.SIZE + %{ + assert ids.ptr_diff % ids.DictAccess.SIZE == 0, \ + 'Accesses array size must be divisible by DictAccess.SIZE' + # A map from key to the list of indices accessing it. + access_indices = {} + for i in range(ids.n_accesses): + key = memory[ids.dict_accesses.address_ + ids.DictAccess.SIZE * i] + access_indices.setdefault(key, []).append(i) + # Descending list of keys. + keys = sorted(access_indices.keys(), reverse=True) + # Are the keys used bigger than range_check bound. 
+ ids.big_keys = 1 if keys[0] >= range_check_builtin.bound else 0 + ids.first_key = key = keys.pop() + %} + + # Call inner. + if big_keys != 0: + tempvar range_check_ptr = range_check_ptr + else: + assert first_key = [range_check_ptr] + tempvar range_check_ptr = range_check_ptr + 1 + end + let (range_check_ptr, squashed_dict) = squash_dict_inner( + range_check_ptr=range_check_ptr, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end - 1, + key=first_key, + remaining_accesses=n_accesses, + squashed_dict=squashed_dict, + big_keys=big_keys) + %{ vm_exit_scope() %} + return (squashed_dict=squashed_dict) +end + +# Inner tail-recursive function for squash_dict. +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. +# dict_accesses_end_minus1 - a pointer to the end of said array, minus 1. +# key - current DictAccess key to check. +# remaining_accesses - remaining number of accesses that need to be accounted for. Starts with +# the total number of entries in dict_accesses array, and slowly decreases until it reaches 0. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Hints: +# keys - a descending list of the keys for which we have accesses. Destroyed in the process. +# access_indices - A map from key to a descending list of indices in the dict_accesses array that +# access this key. Destroyed in the process. +# +# Returns: +# range_check_ptr - updated range check builtin pointer. +# squashed_dict - end pointer to squashed_dict. +func squash_dict_inner( + range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end_minus1 : felt*, key, + remaining_accesses, squashed_dict : DictAccess*, big_keys) -> ( + range_check_ptr, squashed_dict : DictAccess*): + alloc_locals + + let dict_diff : DictAccess* = squashed_dict + + # Loop to verify chronological accesses to the key. + # These values are not needed from previous iteration. + struct LoopTemps: + member index_delta_minus1 : felt + member index_delta : felt + member ptr_delta : felt + member should_continue : felt + end + # These values are needed from previous iteration. + struct LoopLocals: + member value : felt + member access_ptr : DictAccess* + member range_check_ptr : felt + end + + # Prepare first iteration. + %{ + current_access_indices = sorted(access_indices[key])[::-1] + current_access_index = current_access_indices.pop() + memory[ids.range_check_ptr] = current_access_index + %} + # Check that first access_index >= 0. + tempvar current_access_index = [range_check_ptr] + tempvar ptr_delta = current_access_index * DictAccess.SIZE + + let first_loop_locals = cast(ap, LoopLocals*) + first_loop_locals.access_ptr = dict_accesses + ptr_delta; ap++ + let first_access : DictAccess* = first_loop_locals.access_ptr + first_loop_locals.value = first_access.new_value; ap++ + first_loop_locals.range_check_ptr = range_check_ptr + 1; ap++ + + # Verify first key. + key = first_access.key + + # Write key and first value to dict_diff. + key = dict_diff.key + # Use a local variable, instead of a tempvar, to avoid increasing ap. + local first_value = first_access.prev_value + assert first_value = dict_diff.prev_value + + # Skip loop non-deterministically if necessary. 
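+    # (The loop below verifies the second access onwards, so it is skipped when the current key
+    # has only a single access.)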
+ local should_skip_loop + %{ ids.should_skip_loop = 0 if current_access_indices else 1 %} + jmp skip_loop if should_skip_loop != 0 + + loop: + let prev_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + let loop_temps = cast(ap, LoopTemps*) + let loop_locals = cast(ap + LoopTemps.SIZE, LoopLocals*) + + # Check access_index. + %{ + new_access_index = current_access_indices.pop() + ids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1 + current_access_index = new_access_index + %} + # Check that new access_index > prev access_index. + loop_temps.index_delta_minus1 = [prev_loop_locals.range_check_ptr]; ap++ + loop_temps.index_delta = loop_temps.index_delta_minus1 + 1; ap++ + loop_temps.ptr_delta = loop_temps.index_delta * DictAccess.SIZE; ap++ + loop_locals.access_ptr = prev_loop_locals.access_ptr + loop_temps.ptr_delta; ap++ + + # Check valid transition. + let access : DictAccess* = loop_locals.access_ptr + prev_loop_locals.value = access.prev_value + loop_locals.value = access.new_value; ap++ + + # Verify key. + key = access.key + + # Next range_check_ptr. + loop_locals.range_check_ptr = prev_loop_locals.range_check_ptr + 1; ap++ + + %{ ids.loop_temps.should_continue = 1 if current_access_indices else 0 %} + jmp loop if loop_temps.should_continue != 0; ap++ + + skip_loop: + let last_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + + # Check if address is out of bounds. + %{ assert len(current_access_indices) == 0 %} + [ap] = dict_accesses_end_minus1 - cast(last_loop_locals.access_ptr, felt) + [ap] = [last_loop_locals.range_check_ptr]; ap++ + tempvar n_used_accesses = last_loop_locals.range_check_ptr - range_check_ptr + %{ assert ids.n_used_accesses == len(access_indices[key]) %} + + # Write last value to dict_diff. + last_loop_locals.value = dict_diff.new_value + + let range_check_ptr = last_loop_locals.range_check_ptr + 1 + tempvar remaining_accesses = remaining_accesses - n_used_accesses + + # Exit recursion when done. + if remaining_accesses == 0: + %{ assert len(keys) == 0 %} + return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict + DictAccess.SIZE) + end + + let next_key = [ap] + ap += 1 + # Guess next_key and check that next_key > key. + %{ + assert len(keys) > 0, 'No keys left but remaining_accesses > 0.' 
+ ids.next_key = key = keys.pop() + %} + + if big_keys != 0: + assert_lt_felt{range_check_ptr=range_check_ptr}(a=key, b=next_key) + tempvar dict_accesses = dict_accesses + tempvar dict_accesses_end_minus1 = dict_accesses_end_minus1 + tempvar next_key = next_key + tempvar remaining_accesses = remaining_accesses + else: + assert [range_check_ptr] = next_key - (key + 1) + tempvar range_check_ptr = range_check_ptr + 1 + tempvar dict_accesses = dict_accesses + tempvar dict_accesses_end_minus1 = dict_accesses_end_minus1 + tempvar next_key = next_key + tempvar remaining_accesses = remaining_accesses + end + + return squash_dict_inner( + range_check_ptr=range_check_ptr, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end_minus1, + key=next_key, + remaining_accesses=remaining_accesses, + squashed_dict=squashed_dict + DictAccess.SIZE, + big_keys=big_keys) +end diff --git a/src/starkware/cairo/lang/CMakeLists.txt b/src/starkware/cairo/lang/CMakeLists.txt new file mode 100644 index 0000000..695317c --- /dev/null +++ b/src/starkware/cairo/lang/CMakeLists.txt @@ -0,0 +1,61 @@ +add_subdirectory(builtins) +add_subdirectory(compiler) +add_subdirectory(scripts) +add_subdirectory(tracer) +add_subdirectory(vm) + +python_lib(cairo_version_lib + PREFIX starkware/cairo/lang + + FILES + VERSION + version.py +) + +if (NOT DEFINED CAIRO_PYTHON_INTERPRETER) + set(CAIRO_PYTHON_INTERPRETER python3.7) +endif() + +python_venv(cairo_lang_venv + PYTHON ${CAIRO_PYTHON_INTERPRETER} + LIBS + cairo_bootloader_generate_fact_lib + cairo_common_lib + cairo_compile_lib + cairo_hash_program_lib + cairo_run_lib + cairo_script_lib + ${CAIRO_LANG_VENV_ADDITIONAL_LIBS} +) + +python_venv(cairo_lang_package_venv + PYTHON python3.7 + LIBS + cairo_bootloader_generate_fact_lib + cairo_common_lib + cairo_compile_lib + cairo_hash_program_lib + cairo_run_lib + cairo_script_lib + sharp_client_lib + sharp_client_config_lib +) + +python_lib(cairo_instances_lib + PREFIX starkware/cairo/lang + + FILES + instances.py + ${CAIRO_INSTANCES_LIB_ADDITIONAL_FILES} + + LIBS + cairo_run_builtins_lib +) + + +python_lib(cairo_constants_lib + PREFIX starkware/cairo/lang + + FILES + cairo_constants.py +) diff --git a/src/starkware/cairo/lang/MANIFEST.in b/src/starkware/cairo/lang/MANIFEST.in new file mode 100644 index 0000000..f9bd145 --- /dev/null +++ b/src/starkware/cairo/lang/MANIFEST.in @@ -0,0 +1 @@ +include requirements.txt diff --git a/src/starkware/cairo/lang/README.md b/src/starkware/cairo/lang/README.md new file mode 100644 index 0000000..f17c481 --- /dev/null +++ b/src/starkware/cairo/lang/README.md @@ -0,0 +1,61 @@ +# Introduction + +[Cairo](https://cairo-lang.org/) is a programming language for writing provable programs. + +# Documentation + +The Cairo documentation consists of two parts: "Hello Cairo" and "How Cairo Works?". +Both parts can be found in https://cairo-lang.org/docs/. + +We recommend starting from [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). + +# Installation instructions + +You should be able to download the python package zip file directly from +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.0.3) +and install it using ``pip``. +See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). + +However, if you want to build it yourself, you can build it from the git repository. +It is recommended to run the build inside a docker (as explained below), +since it guarantees that all the dependencies +are installed. 
Alternatively, you can try following the commands in the +[docker file](https://github.com/starkware-libs/cairo-lang/blob/master/Dockerfile). + +## Building using the dockerfile + +*Note*: This section is relevant only if you wish to build the Cairo python-package yourself, +rather than downloading it. + +The root directory holds a dedicated Dockerfile, which automatically builds the package and runs +the unit tests on a simulated Ubuntu 18.04 environment. +You should have docker installed (see https://docs.docker.com/get-docker/). + +Clone the repository and initialize the git submodules using: + +```bash +> git clone git@github.com:starkware-libs/cairo-lang.git +> cd cairo-lang +> git submodule update --init +``` + +Build the docker image: + +```bash +> docker build --tag cairo . +``` + +If everything works, you should see + +```bash +Successfully tagged cairo:latest +``` + +Once the docker image is built, you can fetch the python package zip file using: + +```bash +> container_id=$(docker create cairo) +> docker cp ${container_id}:/app/cairo-lang-0.0.3.zip . +> docker rm -v ${container_id} +``` + diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION new file mode 100644 index 0000000..bcab45a --- /dev/null +++ b/src/starkware/cairo/lang/VERSION @@ -0,0 +1 @@ +0.0.3 diff --git a/src/starkware/cairo/lang/__init__.py b/src/starkware/cairo/lang/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/starkware/cairo/lang/builtins/CMakeLists.txt b/src/starkware/cairo/lang/builtins/CMakeLists.txt new file mode 100644 index 0000000..f59e922 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/CMakeLists.txt @@ -0,0 +1,47 @@ +python_lib(cairo_run_builtins_lib + PREFIX starkware/cairo/lang/builtins + + FILES + checkpoints/instance_def.py + checkpoints/checkpoints_builtin_runner.py + hash/hash_builtin_runner.py + hash/instance_def.py + range_check/instance_def.py + range_check/range_check_builtin_runner.py + signature/instance_def.py + signature/signature_builtin_runner.py + + LIBS + cairo_relocatable + cairo_vm_lib + starkware_python_utils_lib +) + +python_lib(cairo_run_builtins_test_utils_lib + PREFIX starkware/cairo/lang/builtins + FILES + builtin_runner_test_utils.py + + LIBS + cairo_compile_lib + cairo_run_lib +) + +full_python_test(cairo_run_builtins_test + PREFIX starkware/cairo/lang/builtins + PYTHON python3.7 + TESTED_MODULES starkware/cairo/lang/builtins + + FILES + range_check/range_check_builtin_runner_test.py + signature/signature_builtin_runner_test.py + + LIBS + cairo_common_lib + cairo_compile_lib + cairo_run_lib + cairo_run_builtins_lib + cairo_run_builtins_test_utils_lib + starkware_python_test_utils_lib + pip_pytest +) diff --git a/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py b/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py new file mode 100644 index 0000000..e67c239 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py @@ -0,0 +1,17 @@ +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo +from starkware.cairo.lang.vm.cairo_runner import CairoRunner + +PRIME = 2**251 + 17 * 2**192 + 1 + + +def compile_and_run(code: str): + """ + Compiles the given code and runs it in the VM. 
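+
+    Example (illustrative): compile_and_run('func main():\n    ret\nend\n')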
+ """ + program = compile_cairo(code, PRIME) + runner = CairoRunner(program, layout='small', proof_mode=False) + runner.initialize_segments() + end = runner.initialize_main_entrypoint() + runner.initialize_vm({}) + runner.run_until_pc(end) + runner.end_run() diff --git a/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py b/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py new file mode 100644 index 0000000..aaff0a4 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/checkpoints/checkpoints_builtin_runner.py @@ -0,0 +1,37 @@ +from typing import List + +from starkware.cairo.lang.builtins.checkpoints.instance_def import CELLS_PER_SAMPLE +from starkware.cairo.lang.vm.builtin_runner import SimpleBuiltinRunner +from starkware.python.math_utils import safe_div + + +class CheckpointsBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, name: str, included: bool, sample_ratio: int): + self.sample_ratio = sample_ratio + self.samples: List = [] + super().__init__(name, included, sample_ratio, CELLS_PER_SAMPLE) + + def finalize_segments(self, runner): + memory = runner.vm.run_context.memory + memory[self.stop_ptr] = 0 + memory[self.stop_ptr + 1] = 0 + super().finalize_segments(runner) + + def get_used_cells_and_allocated_size(self, runner): + size = self.get_used_cells(runner) + return size, size + + def sample(self, step, pc, fp): + self.samples.append((step, pc, fp)) + + def relocate(self, relocate_value): + self.samples = [tuple(map(relocate_value, sample)) for sample in self.samples] + + def air_private_input(self, runner): + return {self.name: [ + { + 'index': safe_div(step, self.sample_ratio), + 'pc': hex(pc), + 'fp': hex(fp) + } + for step, pc, fp in self.samples]} diff --git a/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py b/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py new file mode 100644 index 0000000..eac4489 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/checkpoints/instance_def.py @@ -0,0 +1,11 @@ +import dataclasses + +# Each sample consists of 2 cells (required pc and required fp). +CELLS_PER_SAMPLE = 2 + + +@dataclasses.dataclass +class CheckpointsInstanceDef: + # Defines the ratio between the number of steps to the number of samples. + # For every sample_ratio steps, we have one sample. 
+ sample_ratio: int diff --git a/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py b/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py new file mode 100644 index 0000000..4ff2953 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, Optional, Set + +from starkware.cairo.lang.builtins.hash.instance_def import CELLS_PER_HASH +from starkware.cairo.lang.vm.builtin_runner import BuiltinVerifier, SimpleBuiltinRunner +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +from starkware.python.math_utils import safe_div + + +class HashBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, name: str, included: bool, ratio: int, hash_func): + super().__init__(name, included, ratio, CELLS_PER_HASH) + self.hash_func = hash_func + self.stop_ptr: Optional[RelocatableValue] = None + self.verified_addresses: Set[MaybeRelocatable] = set() + + def add_auto_deduction_rules(self, runner): + def rule(vm, addr, verified_addresses): + memory = vm.run_context.memory + if addr.offset % CELLS_PER_HASH != 2: + return + if addr in verified_addresses: + return + if addr - 1 not in memory or addr - 2 not in memory: + return + assert vm.is_integer_value(memory[addr - 2]), \ + f'{self.name} builtin: Expected integer at address {addr - 2}. ' + \ + f'Got: {memory[addr - 2]}.' + assert vm.is_integer_value(memory[addr - 1]), \ + f'{self.name} builtin: Expected integer at address {addr - 1}. ' + \ + f'Got: {memory[addr - 1]}.' + res = self.hash_func(memory[addr - 2], memory[addr - 1]) + verified_addresses.add(addr) + return res + + runner.vm.add_auto_deduction_rule(self.base.segment_index, rule, self.verified_addresses) + + def air_private_input(self, runner) -> Dict[str, Any]: + assert self.base is not None, 'Uninitialized self.base.' + res: Dict[int, Any] = {} + for addr, val in runner.vm_memory.items(): + if not isinstance(addr, RelocatableValue) or \ + addr.segment_index != self.base.segment_index: + continue + idx = addr.offset // CELLS_PER_HASH + typ = addr.offset % CELLS_PER_HASH + if typ == 2: + continue + + assert isinstance(val, int) + res.setdefault(idx, {'index': idx})['x' if typ == 0 else 'y'] = hex(val) + + for index, item in res.items(): + assert 'x' in item, f'Missing first input of {self.name} instance {index}.' + assert 'y' in item, f'Missing second input of {self.name} instance {index}.' 
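+        # At this point every recorded hash instance has both of its inputs. The returned private
+        # input maps the builtin name to entries of the form {'index': i, 'x': ..., 'y': ...},
+        # sorted by instance index so the output is deterministic.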
+ + return {self.name: sorted(res.values(), key=lambda item: item['index'])} + + def get_additional_data(self): + return [list(RelocatableValue.to_tuple(x)) for x in sorted(self.verified_addresses)] + + def extend_additional_data(self, data, relocate_callback, data_is_trusted=True): + if not data_is_trusted: + return + + for addr in data: + self.verified_addresses.add(relocate_callback(RelocatableValue.from_tuple(addr))) + + +class HashBuiltinVerifier(BuiltinVerifier): + def __init__(self, included: bool, ratio): + self.included = included + self.ratio = ratio + + def expected_stack(self, public_input): + if not self.included: + return [], [] + + addresses = public_input.memory_segments['pedersen'] + max_size = CELLS_PER_HASH * safe_div(public_input.n_steps, self.ratio) + assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ + addresses.begin_addr + max_size < 2**64 + return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/hash/instance_def.py b/src/starkware/cairo/lang/builtins/hash/instance_def.py new file mode 100644 index 0000000..f95f713 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/hash/instance_def.py @@ -0,0 +1,23 @@ +import dataclasses +from typing import Optional + +# Each hash consists of 3 cells (two inputs and one output). +CELLS_PER_HASH = 3 + + +@dataclasses.dataclass +class PedersenInstanceDef: + # Defines the ratio between the number of steps to the number of pedersen instances. + # For every ratio steps, we have one instance. + ratio: int + + # Split to this many different components - for optimization. + repetitions: int + + # Size of hash. + element_height: int + element_bits: int + # Number of inputs for hash. + n_inputs: int + # The upper bound on the hash inputs. If None, the upper bound is 2^element_bits. + hash_limit: Optional[int] = None diff --git a/src/starkware/cairo/lang/builtins/range_check/instance_def.py b/src/starkware/cairo/lang/builtins/range_check/instance_def.py new file mode 100644 index 0000000..1e1dce4 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/range_check/instance_def.py @@ -0,0 +1,13 @@ +import dataclasses + +CELLS_PER_RANGE_CHECK = 1 + + +@dataclasses.dataclass +class RangeCheckInstanceDef: + # Defines the ratio between the number of steps to the number of range check instances. + # For every ratio steps, we have one instance. + ratio: int + # Number of 16-bit range checks that will be used for each instance of the builtin. + # For example, n_parts=8 defines the range [0, 2^128). 
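+    # In general, each instance checks that its value lies in [0, 2^(16 * n_parts)).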
+ n_parts: int diff --git a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py new file mode 100644 index 0000000..2ec1dea --- /dev/null +++ b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py @@ -0,0 +1,87 @@ +from typing import Any, Dict, Optional, Tuple + +from starkware.cairo.lang.vm.builtin_runner import BuiltinVerifier, SimpleBuiltinRunner +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.python.math_utils import safe_div + + +class RangeCheckBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, included: bool, ratio, inner_rc_bound, n_parts): + super().__init__('range_check', included, ratio) + self.inner_rc_bound = inner_rc_bound + self.bound = inner_rc_bound ** n_parts + self.n_parts = n_parts + + def add_validation_rules(self, runner): + def rule(memory, addr): + value = memory[addr] + assert isinstance(value, int), \ + f'Range-check builtin: Expected value at address {addr} to be an integer. ' \ + f'Got: {value}.' + # The range check builtin asserts that 0 <= value < BOUND. + # For example, if the layout uses 8 16-bit range-checks per instance, + # bound will be 2**(16 * 8) = 2**128. + assert 0 <= value < self.bound, \ + f'Value {value}, in range check builtin {addr - self.base}, is out of range ' \ + f'[0, {self.bound}).' + return {addr} + + runner.vm.add_validation_rule(self.base.segment_index, rule) + + def air_private_input(self, runner) -> Dict[str, Any]: + assert self.base is not None, 'Uninitialized self.base.' + res: Dict[int, Any] = {} + for addr, val in runner.vm_memory.items(): + if not isinstance(addr, RelocatableValue) or \ + addr.segment_index != self.base.segment_index: + continue + idx = addr.offset + + assert isinstance(val, int) + res[idx] = {'index': idx, 'value': hex(val)} + + return {'range_check': sorted(res.values(), key=lambda item: item['index'])} + + def get_range_check_usage(self, runner) -> Optional[Tuple[int, int]]: + assert self.base is not None, 'Uninitialized self.base.' + rc_min = None + rc_max = None + for addr, val in runner.vm_memory.items(): + if not isinstance(addr, RelocatableValue) or \ + addr.segment_index != self.base.segment_index: + continue + + # Split val into n_parts parts. + for _ in range(self.n_parts): + part_val = val % self.inner_rc_bound + + if rc_min is None: + rc_min = rc_max = part_val + else: + rc_min = min(rc_min, part_val) + rc_max = max(rc_max, part_val) + val //= self.inner_rc_bound + if rc_min is None or rc_max is None: + return None + return rc_min, rc_max + + def get_used_perm_range_check_units(self, runner) -> int: + used_cells, _ = self.get_used_cells_and_allocated_size(runner) + # Each cell in the range check segment requires n_parts range check units. 
+ return used_cells * self.n_parts + + +class RangeCheckBuiltinVerifier(BuiltinVerifier): + def __init__(self, included: bool, ratio): + self.included = included + self.ratio = ratio + + def expected_stack(self, public_input): + if not self.included: + return [], [] + + addresses = public_input.memory_segments['range_check'] + max_size = safe_div(public_input.n_steps, self.ratio) + assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ + addresses.begin_addr + max_size < 2**64 + return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py new file mode 100644 index 0000000..dc17018 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py @@ -0,0 +1,30 @@ +import pytest + +from starkware.cairo.lang.builtins.builtin_runner_test_utils import PRIME, compile_and_run +from starkware.cairo.lang.vm.vm import VmException + + +def test_validation_rules(): + CODE_FORMAT = """ +%builtins range_check + +func main(range_check_ptr) -> (range_check_ptr): + assert [range_check_ptr] = {value} + return (range_check_ptr=range_check_ptr + 1) +end +""" + + # Test valid values. + compile_and_run(CODE_FORMAT.format(value=0)) + compile_and_run(CODE_FORMAT.format(value=1)) + + with pytest.raises( + VmException, + match=f'Value {PRIME - 1}, in range check builtin 0, is out of range ' + r'\[0, {bound}\)'.format(bound=2**128)): + compile_and_run(CODE_FORMAT.format(value=-1)) + + with pytest.raises( + VmException, + match=f'Range-check builtin: Expected value at address 2:0 to be an integer. Got: 2:0'): + compile_and_run(CODE_FORMAT.format(value='range_check_ptr')) diff --git a/src/starkware/cairo/lang/builtins/signature/instance_def.py b/src/starkware/cairo/lang/builtins/signature/instance_def.py new file mode 100644 index 0000000..af4354e --- /dev/null +++ b/src/starkware/cairo/lang/builtins/signature/instance_def.py @@ -0,0 +1,18 @@ +import dataclasses + +# Each signature consists of 2 cells (a public key and a message). +CELLS_PER_SIGNATURE = 2 + + +@dataclasses.dataclass +class EcdsaInstanceDef: + # Defines the ratio between the number of steps to the number of ECDSA instances. + # For every ratio steps, we have one instance. + ratio: int + + # Split to this many different components - for optimization. + repetitions: int + + # Size of hash. + height: int + n_hash_bits: int diff --git a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py new file mode 100644 index 0000000..cc1e266 --- /dev/null +++ b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py @@ -0,0 +1,105 @@ +from typing import Any, Dict + +from starkware.cairo.lang.builtins.signature.instance_def import CELLS_PER_SIGNATURE +from starkware.cairo.lang.vm.builtin_runner import BuiltinVerifier, SimpleBuiltinRunner +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.python.math_utils import safe_div + + +class SignatureBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, name: str, included: bool, ratio, process_signature, verify_signature): + """ + 'process_signature' is a function that takes signatures as saved in 'signatures' and + returns a dict representing the signature in the format expected by the component used by + the runner. + It may also assert that the signature is valid. 
+ """ + super().__init__(name, included, ratio, CELLS_PER_SIGNATURE) + self.process_signature = process_signature + self.verify_signature = verify_signature + + # A dict of address -> signature. + self.signatures: Dict = {} + + def add_validation_rules(self, runner): + def rule(memory, addr): + # A signature builtin instance consists of a pair of public key and message. + if addr.offset % CELLS_PER_SIGNATURE == 0 and addr + 1 in memory: + pubkey_addr = addr + msg_addr = addr + 1 + elif addr.offset % CELLS_PER_SIGNATURE == 1 and addr - 1 in memory: + pubkey_addr = addr - 1 + msg_addr = addr + else: + return set() + + pubkey = memory[pubkey_addr] + msg = memory[msg_addr] + assert isinstance(pubkey, int), \ + f'ECDSA builtin: Expected public key at address {pubkey_addr} to be an integer. ' \ + f'Got: {pubkey}.' + assert isinstance(msg, int), \ + f'ECDSA builtin: Expected message hash at address {msg_addr} to be an integer. ' \ + f'Got: {msg}.' + assert pubkey_addr in self.signatures, \ + f'Signature hint is missing for ECDSA builtin at address {pubkey_addr}. ' \ + "Add it using 'ecdsa_builtin.add_signature'." + + signature = self.signatures[pubkey_addr] + assert self.verify_signature(pubkey, msg, signature), \ + f'Signature {signature}, is invalid, with respect to the public key {pubkey}, ' \ + f'and the message hash {msg}.' + return {pubkey_addr, msg_addr} + + runner.vm.add_validation_rule(self.base.segment_index, rule) + + def air_private_input(self, runner) -> Dict[str, Any]: + res: Dict[int, Any] = {} + for (addr, signature) in self.signatures.items(): + addr_offset = addr - self.base + idx = safe_div(addr_offset, CELLS_PER_SIGNATURE) + pubkey = runner.vm_memory[addr] + msg = runner.vm_memory[addr + 1] + res[idx] = { + 'index': idx, + 'pubkey': hex(pubkey), + 'msg': hex(msg), + 'signature_input': self.process_signature(pubkey, msg, signature), + } + + return {self.name: sorted(res.values(), key=lambda item: item['index'])} + + def add_signature(self, addr, signature): + """ + This function should be used in Cairo hints. + """ + assert isinstance(addr, RelocatableValue), \ + f'Expected memory address to be relocatable value. Found: {addr}.' + assert addr.offset % CELLS_PER_SIGNATURE == 0, \ + f'Signature hint must point to the public key cell, not {addr}.' 
+ self.signatures[addr] = signature + + def get_additional_data(self): + return [ + [list(RelocatableValue.to_tuple(addr)), signature] + for addr, signature in sorted(self.signatures.items())] + + def extend_additional_data(self, data, relocate_callback, data_is_trusted=True): + for addr, signature in data: + self.signatures[relocate_callback(RelocatableValue.from_tuple(addr))] = signature + + +class SignatureBuiltinVerifier(BuiltinVerifier): + def __init__(self, included: bool, ratio): + self.included = included + self.ratio = ratio + + def expected_stack(self, public_input): + if not self.included: + return [], [] + + addresses = public_input.memory_segments['signature'] + max_size = safe_div(public_input.n_steps, self.ratio) * CELLS_PER_SIGNATURE + assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ + addresses.begin_addr + max_size < 2**64 + return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py new file mode 100644 index 0000000..a7c656e --- /dev/null +++ b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py @@ -0,0 +1,153 @@ +import dataclasses +from types import SimpleNamespace +from typing import Optional + +import pytest + +from starkware.cairo.lang.builtins.builtin_runner_test_utils import compile_and_run +from starkware.cairo.lang.vm.vm import VmException +from starkware.python.test_utils import maybe_raises + + +@dataclasses.dataclass +class SignatureCodeSections: + """ + Code sections relevant for using the signature builtin. + See code snippet structure below. + """ + hint: str + write_pubkey: str + write_msg: str + + +@dataclasses.dataclass +class SignatureExample: + code_sections: SignatureCodeSections + # Error message received by running the example code, in case there is any. + error_msg: Optional[str] + + +# Constants used for creating a code snippet using the signature builtin. +# See signature_builtin_runner_test.py. +SIG_PTR = 'ecdsa_ptr' +formats = SimpleNamespace( + hint_code_format='%{{ ecdsa_builtin.add_signature({addr}, {signature}) %}}', + pubkey_code_format=f'assert [{SIG_PTR} + SignatureBuiltin.pub_key] = {{pubkey}}', + msg_code_format=f'assert [{SIG_PTR} + SignatureBuiltin.message] = {{msg}}', +) + +# The address is used inside a hint. +VALID_ADDR = f'ids.{SIG_PTR}' +VALID_SIG = ( + 3086480810278599376317923499561306189851900463386393948998357832163236918254, + 598673427589502599949712887611119751108407514580626464031881322743364689811) +constants = SimpleNamespace( + valid_addr=VALID_ADDR, + invalid_addr=VALID_ADDR + ' + 1', + valid_sig=VALID_SIG, + invalid_sig=(VALID_SIG[0] + 1, VALID_SIG[1]), + valid_pubkey=1735102664668487605176656616876767369909409133946409161569774794110049207117, + valid_msg=2718, + invalid_pubkey_or_msg=SIG_PTR, +) + + +class SignatureTest: + """ + Aggregates test cases for the signature builtin runner. + A valid test case is added at initialization and further test cases are added based on the + valid case. 
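+    Each additional case overrides one or more code sections (hint, public key write, message
+    write) of the valid case and records the error message expected at run time, if any.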
+ """ + + def __init__(self): + self.test_cases = {'valid': SignatureExample( + error_msg=None, + code_sections=SignatureCodeSections( + hint=formats.hint_code_format.format( + addr=constants.valid_addr, signature=constants.valid_sig), + write_pubkey=formats.pubkey_code_format.format( + pubkey=constants.valid_pubkey), + write_msg=formats.msg_code_format.format(msg=constants.valid_msg), + ) + )} + + def add_test_case(self, name: str, error_msg: Optional[str], **code_section_changes): + """ + Adds a new test case with the given error message, based on the valid case and the given + changes to it. + """ + self.test_cases[name] = SignatureExample( + code_sections=dataclasses.replace( + self.test_cases['valid'].code_sections, **code_section_changes), + error_msg=error_msg, + ) + + def get_test_cases(self): + return self.test_cases + + +# Signature code snippet structure. +CODE = """ +%builtins ecdsa +from starkware.cairo.common.cairo_builtins import SignatureBuiltin + +func main(ecdsa_ptr) -> (ecdsa_ptr): + {hint} + {write_pubkey} + {write_msg} + return(ecdsa_ptr=ecdsa_ptr + SignatureBuiltin.SIZE) +end +""" + +test = SignatureTest() +test.add_test_case( + name='invalid_signature_address', + error_msg='Signature hint must point to the public key cell, not 2:1.', + hint=formats.hint_code_format.format( + addr=constants.invalid_addr, signature=constants.valid_sig), +) + +test.add_test_case( + name='invalid_signature', + error_msg=( + r'Signature .* is invalid, with respect to the public key ' + '1735102664668487605176656616876767369909409133946409161569774794110049207117, ' + 'and the message hash 2718.'), + hint=formats.hint_code_format.format( + addr=constants.valid_addr, signature=constants.invalid_sig), +) + +test.add_test_case( + name='invalid_public_key', + error_msg='ECDSA builtin: Expected public key at address 2:0 to be an integer. Got: 2:0.', + write_pubkey=formats.pubkey_code_format.format(pubkey=constants.invalid_pubkey_or_msg), +) + +test.add_test_case( + name='invalid_message', + error_msg='ECDSA builtin: Expected message hash at address 2:1 to be an integer. Got: 2:0.', + write_msg=formats.msg_code_format.format(msg=constants.invalid_pubkey_or_msg), +) + +test.add_test_case( + name='missing_hint', + error_msg=( + 'Signature hint is missing for ECDSA builtin at address 2:0. ' + "Add it using 'ecdsa_builtin.add_signature'."), + hint='', +) + +# Missing public key or message would not cause a runtime error, but would fail the prover. +test.add_test_case(name='missing_public_key', error_msg=None, write_pubkey='') +test.add_test_case(name='missing_message', error_msg=None, write_msg='') + +test_cases = test.get_test_cases() + + +@pytest.mark.parametrize('case', test_cases.values(), ids=test_cases.keys()) +def test_validation_rules(case): + code = CODE.format(**dataclasses.asdict(case.code_sections)) + with maybe_raises( + expected_exception=VmException, error_message=case.error_msg, + escape_error_message=False): + compile_and_run(code) diff --git a/src/starkware/cairo/lang/cairo_cmake_rules.cmake b/src/starkware/cairo/lang/cairo_cmake_rules.cmake new file mode 100644 index 0000000..dfcd35b --- /dev/null +++ b/src/starkware/cairo/lang/cairo_cmake_rules.cmake @@ -0,0 +1,121 @@ +# Compiles a Cairo program. 
+# Usage example: +# cairo_compile(mytarget main_compiled.json main.cairo "--debug_info_with_source") +function(cairo_compile TARGET_NAME COMPILED_PROGRAM_NAME SOURCE_FILE COMPILE_FLAGS) + get_lib_info_file(STAMP_FILE cairo_lang_venv) + + set(COMPILED_PROGRAM "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_compiled.json") + + # Choose a file name for the Cairo dependencies of the compiled file. + set(COMPILE_DEPENDENCY_FILE + "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_compile_dependencies.cmake") + # If this is the first build, create an empty dependency file (this file will be overriden when + # cairo-compile is executed with the actual dependencies, using the --cairo_dependencies flag). + if(NOT EXISTS ${COMPILE_DEPENDENCY_FILE}) + file(WRITE ${COMPILE_DEPENDENCY_FILE} "") + endif() + # The following include() will populate the DEPENDENCIES variable with the Cairo files. + include(${COMPILE_DEPENDENCY_FILE}) + + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${COMPILED_PROGRAM_NAME}" + COMMAND "${CMAKE_BINARY_DIR}/src/starkware/cairo/lang/cairo_lang_venv/bin/python" + "-m" "starkware.cairo.lang.compiler.cairo_compile" + "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE_FILE}" + "--output=${CMAKE_CURRENT_BINARY_DIR}/${COMPILED_PROGRAM_NAME}" + "--prime=3618502788666131213697322783095070105623107215331596699973092056135872020481" + "--cairo_path=${CMAKE_SOURCE_DIR}/src" + "--cairo_dependencies=${COMPILE_DEPENDENCY_FILE}" + "${COMPILE_FLAGS}" + DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE_FILE}" + ${COMPILE_DEPENDENCY_FILE} ${DEPENDENCIES} cairo_lang_venv ${STAMP_FILE} + COMMENT "Compiling ${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE_FILE}" + ) + + add_custom_target(${TARGET_NAME} + ALL + DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/${COMPILED_PROGRAM_NAME}" + ) +endfunction() + +# Compiles and runs a Cairo file. +# ARTIFACTS is a ';'-separated list that may contain the following artifacts: +# * trace +# * public_input +function(cairo_compile_run TARGET_NAME FILENAME STEPS ARTIFACTS COMPILE_FLAGS RUN_FLAGS) + get_lib_info_file(STAMP_FILE cairo_lang_venv) + + set(ARTIFACT_LIST ${ARTIFACTS}) + + # Choose a file name for the python dependencies of cairo-run. + set(RUN_DEPENDENCY_FILE "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME}_run_dependencies.cmake") + # If this is the first build, create an empty dependency file (this file will be overriden when + # cairo-run is executed with the actual dependencies, using the --python_dependencies flag). + if(NOT EXISTS ${RUN_DEPENDENCY_FILE}) + file(WRITE ${RUN_DEPENDENCY_FILE} "") + endif() + # The following include() will populate the DEPENDENCIES variable with the python modules used + # by cairo-compile and cairo-run. 
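+  # The DEPENDENCIES variable is consumed by the add_custom_command below, so editing any of
+  # these modules re-triggers the run.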
+ include(${RUN_DEPENDENCY_FILE}) + + if ("trace" IN_LIST ARTIFACT_LIST) + set(MEMORY_FILE "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME}_memory.bin") + set(TRACE_FILE "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME}_trace.bin") + set(TRACE_HEADER "${TRACE_FILE}.h") + set(MEMORY_HEADER "${MEMORY_FILE}.h") + set(GENERATE_TRACE "--trace_file=${TRACE_FILE}") + set(GENERATE_MEMORY "--memory_file=${MEMORY_FILE}") + endif() + if ("public_input" IN_LIST ARTIFACT_LIST) + set(PUBLIC_INPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME}_public_input.json") + set(PUBLIC_INPUT_HEADER "${PUBLIC_INPUT_FILE}.h") + set(GENERATE_PUBLIC_INPUT "--air_public_input=${PUBLIC_INPUT_FILE}") + endif() + + cairo_compile( + "${TARGET_NAME}_compile" + "${TARGET_NAME}_compiled.json" + "${FILENAME}.cairo" + "${COMPILE_FLAGS}") + + add_custom_command( + OUTPUT "${MEMORY_FILE}" "${TRACE_FILE}" "${PUBLIC_INPUT_FILE}" + COMMAND "${CMAKE_BINARY_DIR}/src/starkware/cairo/lang/cairo_lang_venv/bin/python" + "-m" "starkware.cairo.lang.vm.cairo_run" + "--program=${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_compiled.json" + "--steps=${STEPS}" + "--python_dependencies=${RUN_DEPENDENCY_FILE}" + "--proof_mode" + "${RUN_FLAGS}" + "${GENERATE_MEMORY}" + "${GENERATE_TRACE}" + "${GENERATE_PUBLIC_INPUT}" + DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/${FILENAME}.cairo" + "${TARGET_NAME}_compiled.json" + ${RUN_DEPENDENCY_FILE} ${DEPENDENCIES} cairo_lang_venv ${VENV_STAMP} + COMMENT "Executing ${CMAKE_CURRENT_SOURCE_DIR}/${FILENAME}.cairo" + ) + if ("trace" IN_LIST ARTIFACT_LIST) + generate_cpp_resource( + ${TRACE_FILE} + ${TRACE_HEADER} + ${FILENAME}_trace + ) + generate_cpp_resource( + ${MEMORY_FILE} + ${MEMORY_HEADER} + ${FILENAME}_memory + ) + endif() + if ("public_input" IN_LIST ARTIFACT_LIST) + generate_cpp_resource( + ${PUBLIC_INPUT_FILE} + ${PUBLIC_INPUT_HEADER} + ${FILENAME}_public_input + ) + endif() + + add_custom_target(${TARGET_NAME} + DEPENDS "${TRACE_HEADER}" "${MEMORY_HEADER}" "${PUBLIC_INPUT_HEADER}" + ) +endfunction(cairo_compile_run) diff --git a/src/starkware/cairo/lang/cairo_constants.py b/src/starkware/cairo/lang/cairo_constants.py new file mode 100644 index 0000000..84f2145 --- /dev/null +++ b/src/starkware/cairo/lang/cairo_constants.py @@ -0,0 +1 @@ +DEFAULT_PRIME = 2**251 + 17 * 2**192 + 1 diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt new file mode 100644 index 0000000..b0fced4 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -0,0 +1,150 @@ +python_lib(cairo_compile_lib + PREFIX starkware/cairo/lang/compiler + FILES + __init__.py + assembler.py + ast/__init__.py + ast/aliased_identifier.py + ast/arguments.py + ast/bool_expr.py + ast/cairo_types.py + ast/code_elements.py + ast/expr.py + ast/formatting_utils.py + ast/instructions.py + ast/module.py + ast/node.py + ast/notes.py + ast/rvalue.py + ast/types.py + ast/visitor.py + cairo_compile.py + cairo_format.py + cairo.ebnf + constants.py + const_expr_checker.py + debug_info.py + encode.py + error_handling.py + expression_evaluator.py + expression_simplifier.py + expression_transformer.py + fields.py + identifier_definition.py + identifier_manager.py + identifier_manager_field.py + identifier_utils.py + import_loader.py + instruction_builder.py + instruction.py + location_utils.py + module_reader.py + offset_reference.py + parser_transformer.py + parser.py + preprocessor/compound_expressions.py + preprocessor/dependency_graph.py + preprocessor/flow.py + preprocessor/identifier_aware_visitor.py + 
preprocessor/identifier_collector.py + preprocessor/local_variables.py + preprocessor/preprocessor_error.py + preprocessor/preprocessor_utils.py + preprocessor/preprocessor.py + preprocessor/reg_tracking.py + preprocessor/struct_collector.py + preprocessor/unique_labels.py + program.py + references.py + resolve_search_result.py + scoped_name.py + substitute_identifiers.py + type_casts.py + type_system_visitor.py + type_system.py + + LIBS + cairo_constants_lib + cairo_version_lib + starkware_expression_string_lib + starkware_python_utils_lib + pip_marshmallow_dataclass + pip_marshmallow_enum + pip_marshmallow_oneofschema + pip_marshmallow + pip_lark_parser +) + +python_exe(cairo_compile_exe + VENV cairo_lang_venv + MODULE starkware.cairo.lang.compiler.cairo_compile +) + +python_venv(cairo_format_venv + PYTHON python3.7 + LIBS + cairo_compile_lib +) + +python_exe(cairo_format + VENV cairo_format_venv + MODULE starkware.cairo.lang.compiler.cairo_format +) + +python_lib(cairo_compile_test_utils_lib + PREFIX starkware/cairo/lang/compiler + FILES + preprocessor/preprocessor_test_utils.py + test_utils.py + + LIBS + cairo_compile_lib + pip_pytest +) + +full_python_test(cairo_compile_test + PREFIX starkware/cairo/lang/compiler + PYTHON python3.7 + TESTED_MODULES starkware/cairo/lang/compiler + + FILES + assembler_test.py + ast_objects_test.py + ast/formatting_utils_test.py + cairo_compile_test.py + encode_test.py + error_handling_test.py + expression_evaluator_test.py + expression_simplifier_test.py + identifier_definition_test.py + identifier_manager_field_test.py + identifier_manager_test.py + identifier_utils_test.py + import_loader_test.py + instruction_builder_test.py + instruction_test.py + module_reader_test.py + offset_reference_test.py + parser_errors_test.py + parser_test_utils.py + parser_test.py + preprocessor/compound_expressions_test.py + preprocessor/dependency_graph_test.py + preprocessor/flow_test.py + preprocessor/identifier_collector_test.py + preprocessor/local_variables_test.py + preprocessor/preprocessor_test.py + preprocessor/reg_tracking_test.py + preprocessor/struct_collector_test.py + preprocessor/unique_labels_test.py + references_test.py + resolve_search_result_test.py + scoped_name_test.py + type_casts_test.py + type_system_visitor_test.py + + LIBS + cairo_compile_lib + cairo_compile_test_utils_lib + pip_pytest +) diff --git a/src/starkware/cairo/lang/compiler/__init__.py b/src/starkware/cairo/lang/compiler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/starkware/cairo/lang/compiler/assembler.py b/src/starkware/cairo/lang/compiler/assembler.py new file mode 100644 index 0000000..65a05c4 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/assembler.py @@ -0,0 +1,49 @@ +from typing import Dict, List + +from starkware.cairo.lang.compiler.debug_info import DebugInfo, HintLocation, InstructionLocation +from starkware.cairo.lang.compiler.encode import encode_instruction +from starkware.cairo.lang.compiler.instruction_builder import build_instruction +from starkware.cairo.lang.compiler.preprocessor.preprocessor import PreprocessedProgram +from starkware.cairo.lang.compiler.program import CairoHint, Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +def assemble( + preprocessed_program: PreprocessedProgram, main_scope: ScopedName = ScopedName(), + add_debug_info: bool = False, file_contents_for_debug_info: Dict[str, str] = {}) -> Program: + data: List[int] = [] + hints: Dict[int, CairoHint] = {} + debug_info = 
DebugInfo(instruction_locations={}, file_contents=file_contents_for_debug_info) \ + if add_debug_info else None + + for inst in preprocessed_program.instructions: + if inst.hint: + hints[len(data)] = CairoHint( + code=inst.hint.hint_code, + accessible_scopes=inst.accessible_scopes, + flow_tracking_data=inst.flow_tracking_data) + if debug_info is not None and inst.instruction.location is not None: + hint_location = None + if inst.hint is not None and inst.hint.location is not None: + hint_location = HintLocation( + location=inst.hint.location, + n_prefix_newlines=inst.hint.n_prefix_newlines, + ) + debug_info.instruction_locations[len(data)] = \ + InstructionLocation( + inst=inst.instruction.location, + hint=hint_location, + accessible_scopes=inst.accessible_scopes, + flow_tracking_data=inst.flow_tracking_data) + data += [word for word in encode_instruction( + build_instruction(inst.instruction), prime=preprocessed_program.prime)] + + return Program( + prime=preprocessed_program.prime, + data=data, + hints=hints, + main_scope=main_scope, + identifiers=preprocessed_program.identifiers, + builtins=preprocessed_program.builtins, + reference_manager=preprocessed_program.reference_manager, + debug_info=debug_info) diff --git a/src/starkware/cairo/lang/compiler/assembler_test.py b/src/starkware/cairo/lang/compiler/assembler_test.py new file mode 100644 index 0000000..d0dc430 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/assembler_test.py @@ -0,0 +1,57 @@ +import pytest + +from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition, LabelDefinition +from starkware.cairo.lang.compiler.identifier_manager import ( + IdentifierManager, MissingIdentifierError) +from starkware.cairo.lang.compiler.preprocessor.flow import ReferenceManager +from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +def test_main_scope(): + identifiers = IdentifierManager.from_dict({ + ScopedName.from_string('a.b'): ConstDefinition(value=1), + ScopedName.from_string('x.y.z'): ConstDefinition(value=2), + }) + reference_manager = ReferenceManager() + + program = Program( + prime=0, data=[], hints={}, builtins=[], main_scope=ScopedName.from_string('a'), + identifiers=identifiers, reference_manager=reference_manager) + + # Check accessible identifiers. + assert program.get_identifier('b', ConstDefinition).value == 1 + + # Ensure inaccessible identifiers. + with pytest.raises(MissingIdentifierError, match="Unknown identifier 'a'."): + program.get_identifier('a.b', ConstDefinition) + + with pytest.raises(MissingIdentifierError, match="Unknown identifier 'x'."): + program.get_identifier('x.y', ConstDefinition) + + with pytest.raises(MissingIdentifierError, match="Unknown identifier 'y'."): + program.get_identifier('y', ConstDefinition) + + # Full name lookup. + assert program.get_identifier('a.b', ConstDefinition, full_name_lookup=True).value == 1 + assert program.get_identifier('x.y.z', ConstDefinition, full_name_lookup=True).value == 2 + + +def test_program_start_property(): + identifiers = IdentifierManager.from_dict({ + ScopedName.from_string('some.main.__start__'): LabelDefinition(3), + }) + reference_manager = ReferenceManager() + main_scope = ScopedName.from_string('some.main') + + # The label __start__ is in identifiers. 
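+    # In this case program.start should point to the label's pc (3); without the label it
+    # defaults to 0 (checked below).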
+ program = Program( + prime=0, data=[], hints={}, builtins=[], main_scope=main_scope, identifiers=identifiers, + reference_manager=reference_manager) + assert program.start == 3 + + # The label __start__ is not in identifiers. + program = Program( + prime=0, data=[], hints={}, builtins=[], main_scope=main_scope, + identifiers=IdentifierManager(), reference_manager=reference_manager) + assert program.start == 0 diff --git a/src/starkware/cairo/lang/compiler/ast/__init__.py b/src/starkware/cairo/lang/compiler/ast/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/starkware/cairo/lang/compiler/ast/aliased_identifier.py b/src/starkware/cairo/lang/compiler/ast/aliased_identifier.py new file mode 100644 index 0000000..4666b75 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/aliased_identifier.py @@ -0,0 +1,25 @@ +import dataclasses +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.expr import ExprIdentifier +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location + + +@dataclasses.dataclass +class AliasedIdentifier(AstNode): + orig_identifier: ExprIdentifier + local_name: Optional[ExprIdentifier] + location: Optional[Location] = LocationField + + def format(self): + return f'{self.orig_identifier.format()}' + \ + (f' as {self.local_name.format()}' if self.local_name else '') + + @property + def identifier(self): + return self.local_name if self.local_name is not None else self.orig_identifier + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.orig_identifier, self.local_name] diff --git a/src/starkware/cairo/lang/compiler/ast/arguments.py b/src/starkware/cairo/lang/compiler/ast/arguments.py new file mode 100644 index 0000000..edf6989 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/arguments.py @@ -0,0 +1,23 @@ +import dataclasses +from typing import List, Optional, Sequence + +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import NoteListField, Notes +from starkware.cairo.lang.compiler.ast.types import TypedIdentifier +from starkware.cairo.lang.compiler.error_handling import Location + + +@dataclasses.dataclass +class IdentifierList(AstNode): + identifiers: List[TypedIdentifier] + notes: List[Notes] = NoteListField # type: ignore + location: Optional[Location] = LocationField + + def get_particles(self): + for note in self.notes: + note.assert_no_comments() + return [x.format() for x in self.identifiers] + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.identifiers diff --git a/src/starkware/cairo/lang/compiler/ast/bool_expr.py b/src/starkware/cairo/lang/compiler/ast/bool_expr.py new file mode 100644 index 0000000..ca6769b --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/bool_expr.py @@ -0,0 +1,22 @@ +import dataclasses +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.expr import Expression +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location + + +@dataclasses.dataclass +class BoolExpr(AstNode): + a: Expression + b: Expression + eq: bool + location: Optional[Location] = LocationField + + def 
get_particles(self): + relation = '==' if self.eq else '!=' + return [f'{self.a.format()} {relation} ', self.b.format()] + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] diff --git a/src/starkware/cairo/lang/compiler/ast/cairo_types.py b/src/starkware/cairo/lang/compiler/ast/cairo_types.py new file mode 100644 index 0000000..0af22dd --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/cairo_types.py @@ -0,0 +1,95 @@ +import dataclasses +from abc import abstractmethod +from enum import Enum, auto +from typing import List, Optional, Sequence + +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.scoped_name import ScopedName + + +class CairoType(AstNode): + location: Optional[Location] + + @abstractmethod + def format(self) -> str: + """ + Returns a representation of the type as a string. + """ + + def get_pointer_type(self) -> 'CairoType': + """ + Returns a type of a pointer to the current type. + """ + return TypePointer(pointee=self, location=self.location) + + +@dataclasses.dataclass +class TypeFelt(CairoType): + location: Optional[Location] = LocationField + + def format(self): + return 'felt' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class TypePointer(CairoType): + pointee: CairoType + location: Optional[Location] = LocationField + + def format(self): + return f'{self.pointee.format()}*' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.pointee] + + +@dataclasses.dataclass +class TypeStruct(CairoType): + scope: ScopedName + # Indicates whether scope refers to the fully resolved name. + is_fully_resolved: bool + location: Optional[Location] = LocationField + + def format(self): + return str(self.scope) + + @property + def resolved_scope(self): + """ + Verifies that is_fully_resolved=True and returns scope. + """ + assert self.is_fully_resolved, 'Type is expected to be fully resolved at this point.' + return self.scope + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class TypeTuple(CairoType): + """ + Type for a tuple. + """ + members: List[CairoType] + location: Optional[Location] = LocationField + + def format(self): + member_formats = [member.format() for member in self.members] + return f"({', '.join(member_formats)})" + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.members + + +class CastType(Enum): + # When an explicit cast occurs using 'cast(*, *)'. + EXPLICIT = 0 + # When unpacking occurs (e.g., 'let (x : T) = foo()'). + UNPACKING = auto() + # When a variable is initialized (e.g., 'tempvar x : T = 5'). 
+ ASSIGN = auto() diff --git a/src/starkware/cairo/lang/compiler/ast/code_elements.py b/src/starkware/cairo/lang/compiler/ast/code_elements.py new file mode 100644 index 0000000..e94bed0 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/code_elements.py @@ -0,0 +1,648 @@ +import dataclasses +from abc import abstractmethod +from typing import List, Optional, Sequence + +from starkware.cairo.lang.compiler.ast.aliased_identifier import AliasedIdentifier +from starkware.cairo.lang.compiler.ast.arguments import IdentifierList +from starkware.cairo.lang.compiler.ast.bool_expr import BoolExpr +from starkware.cairo.lang.compiler.ast.expr import ExprAssignment, Expression, ExprIdentifier +from starkware.cairo.lang.compiler.ast.formatting_utils import ( + INDENTATION, LocationField, ParticleFormattingConfig, create_particle_sublist, + particles_in_lines) +from starkware.cairo.lang.compiler.ast.instructions import InstructionAst +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import NoteListField, Notes +from starkware.cairo.lang.compiler.ast.rvalue import Rvalue, RvalueCall, RvalueFuncCall +from starkware.cairo.lang.compiler.ast.types import TypedIdentifier +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.python.utils import indent + + +class CodeElement(AstNode): + @abstractmethod + def format(self, allowed_line_length): + """ + Formats the code element, without exceeding a line length of `allowed_line_length`. + """ + + +@dataclasses.dataclass +class CodeElementInstruction(CodeElement): + instruction: InstructionAst + + def get_particles(self): + return [self.instruction.format()] + + def format(self, allowed_line_length): + return self.instruction.format() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.instruction] + + +@dataclasses.dataclass +class CodeElementConst(CodeElement): + identifier: ExprIdentifier + expr: Expression + + def format(self, allowed_line_length): + return f'const {self.identifier.format()} = {self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementMember(CodeElement): + typed_identifier: TypedIdentifier + + def format(self, allowed_line_length): + return f'member {self.typed_identifier.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier] + + +@dataclasses.dataclass +class CodeElementReference(CodeElement): + typed_identifier: TypedIdentifier + expr: Expression + + def format(self, allowed_line_length): + return f'let {self.typed_identifier.format()} = {self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementLocalVariable(CodeElement): + """ + Represents a statement of the form: + local x [: expr_type] = [expr] + + Both the expr_type and the initialization expr are optional. 
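+    For example: 'local x', 'local x : felt' and 'local x = y + 1' are all valid forms.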
+ """ + typed_identifier: TypedIdentifier + expr: Optional[Expression] + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + assignment = '' if self.expr is None else f' = {self.expr.format()}' + return f'local {self.typed_identifier.format()}{assignment}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementTemporaryVariable(CodeElement): + """ + Represents a statement of the form: + tempvar x = expr. + """ + typed_identifier: TypedIdentifier + expr: Expression + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f'tempvar {self.typed_identifier.format()} = {self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.expr] + + +@dataclasses.dataclass +class CodeElementCompoundAssertEq(CodeElement): + """ + Represents the statement "assert a = b" for two (compound) expressions a, b. + Unlike AssertEqInstruction, a CodeElementCompoundAssertEq may translate to a few instructions + to deal with expressions which contain more than one operation. + """ + a: Expression + b: Expression + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f'assert {self.a.format()} = {self.b.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class CodeElementStaticAssert(CodeElement): + a: Expression + b: Expression + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f'static_assert {self.a.format()} == {self.b.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class CodeElementReturn(CodeElement): + """ + Represents a statement of the form: + return ([ident=]expr, ...). + """ + exprs: List[ExprAssignment] + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + expr_codes = [x.format() for x in self.exprs] + particles = ['return (', create_particle_sublist(expr_codes, ')')] + + return particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=True)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.exprs + + +@dataclasses.dataclass +class CodeElementTailCall(CodeElement): + """ + Represents a statement of the form: + return func_ident([ident=]expr, ...). + """ + func_call: RvalueFuncCall + location: Optional[Location] = LocationField + + def get_particles(self): + particales = self.func_call.get_particles() + return ['return ' + particales[0]] + particales[1:] + + def format(self, allowed_line_length): + return particles_in_lines( + particles=self.get_particles(), + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=True)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.func_call] + + +@dataclasses.dataclass +class CodeElementFuncCall(CodeElement): + """ + Represents a statement of the form: + func_ident([ident=]expr, ...). 
+ """ + func_call: RvalueFuncCall + + def get_particles(self): + return self.func_call.get_particles() + + def format(self, allowed_line_length): + return self.func_call.format(allowed_line_length) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.func_call] + + +@dataclasses.dataclass +class CodeElementReturnValueReference(CodeElement): + """ + Represents one of the references below. + let x [: type] = func(...) + let x [: type] = call func + let x [: type] = call rel 5 + where: + 'x [: type]' is the 'typed_identifier' + 'func(...)' is the 'func_call'. + """ + typed_identifier: TypedIdentifier + func_call: RvalueCall + + def format(self, allowed_line_length): + call_particles = self.func_call.get_particles() + first_particle = f'let {self.typed_identifier.format()} = ' + call_particles[0] + + return particles_in_lines( + particles=[first_particle] + call_particles[1:], + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=True)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typed_identifier, self.func_call] + + +@dataclasses.dataclass +class CodeElementUnpackBinding(CodeElement): + """ + Represents return value unpacking statement of the form: + let (a, b, c) = func(...) + where: + '(a, b, c)' is the 'unpacking_list' + 'func(...)' is the 'rvalue'. + """ + unpacking_list: IdentifierList + rvalue: Rvalue + + def format(self, allowed_line_length): + particles = self.rvalue.get_particles() + + end_particle = ') = ' + particles[0] + particles = ['let ('] + \ + create_particle_sublist(self.unpacking_list.get_particles(), end_particle) + \ + particles[1:] + + return particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=True)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.unpacking_list, self.rvalue] + + +@dataclasses.dataclass +class CodeElementLabel(CodeElement): + identifier: ExprIdentifier + + def format(self, allowed_line_length): + return f'{self.identifier.format()}:' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier] + + +@dataclasses.dataclass +class CodeElementHint(CodeElement): + hint_code: str + # The number of new lines following the "%{" symbol. + n_prefix_newlines: int + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + if self.hint_code == '': + return '%{\n%}' + if '\n' not in self.hint_code: + # One liner. 
+ return f'%{{ {self.hint_code} %}}' + code = indent(self.hint_code, INDENTATION) + return f'%{{\n{code}\n%}}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class CodeElementEmptyLine(CodeElement): + def format(self, allowed_line_length): + return '' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class CommentedCodeElement(AstNode): + code_elm: CodeElement + comment: Optional[str] + + def format(self, allowed_line_length): + elm_str = self.code_elm.format(allowed_line_length=allowed_line_length) + comment_str = f'#{self.comment}' if self.comment is not None else '' + separator = ' ' if elm_str != '' and comment_str != '' else '' + return elm_str + separator + comment_str.rstrip() + + def fix_comment_spaces(self, allow_additional_comment_spaces: bool): + """ + Comments should start with exactly one space after '#' except for some cases (in which + allow_additional_comment_spaces=True). + Returns a copy of this instance with a fixed comment. + """ + comment = self.comment + + if comment is None: + return self + + if not allow_additional_comment_spaces: + comment = comment.strip() + if not comment.startswith(' '): + comment = ' ' + comment + + return CommentedCodeElement(code_elm=self.code_elm, comment=comment) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.code_elm] + + +@dataclasses.dataclass +class CodeBlock(AstNode): + code_elements: List[CommentedCodeElement] + + def format(self, allowed_line_length): + code_elements = remove_redundant_empty_lines(self.code_elements) + code_elements = add_empty_lines_before_labels(code_elements) + code_elements = fix_comment_spaces(code_elements) + + return ''.join(f'{code_elm.format(allowed_line_length)}\n' for code_elm in code_elements) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.code_elements + + +@dataclasses.dataclass +class CodeElementScoped(CodeElement): + """ + Represents a list of code elements that should be handled inside a scope. + This class does not appear naturally in the parsed AST. + """ + scope: ScopedName + code_elements: List[CodeElement] + + def format(self, allowed_line_length): + raise NotImplementedError(f'Formatting {type(self).__name__} is not supported.') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.code_elements + + +@dataclasses.dataclass +class CodeElementFunction(CodeElement): + """ + Represents either a 'func', 'namespace' or 'struct' statement. + For example: + func foo(x, y) -> (z, w): + return (z=x, w=y) + end + """ + # The type of the code element. Either 'func', 'namespace' or 'struct'. 
+ element_type: str + identifier: ExprIdentifier + arguments: IdentifierList + implicit_arguments: Optional[IdentifierList] + returns: Optional[IdentifierList] + code_block: CodeBlock + decorators: List[ExprIdentifier] + + ARGUMENT_SCOPE = ScopedName.from_string('Args') + IMPLICIT_ARGUMENT_SCOPE = ScopedName.from_string('ImplicitArgs') + RETURN_SCOPE = ScopedName.from_string('Return') + + @property + def name(self): + return self.identifier.name + + def format(self, allowed_line_length): + code = self.code_block.format(allowed_line_length=allowed_line_length - INDENTATION) + code = indent(code, INDENTATION) + if self.element_type in ['struct', 'namespace']: + particles = [f'{self.element_type} {self.name}:'] + else: + if self.implicit_arguments is not None: + first_particle_suffix = '{' + implicit_args_particles = [ + create_particle_sublist(self.implicit_arguments.get_particles(), '}(')] + else: + first_particle_suffix = '(' + implicit_args_particles = [] + + if self.returns is not None: + particles = [ + f'{self.element_type} {self.name}{first_particle_suffix}', + *implicit_args_particles, + create_particle_sublist(self.arguments.get_particles(), ') -> ('), + create_particle_sublist(self.returns.get_particles(), '):')] + else: + particles = [ + f'{self.element_type} {self.name}{first_particle_suffix}', + *implicit_args_particles, + create_particle_sublist(self.arguments.get_particles(), '):')] + + decorators = ''.join(f'@{decorator.format()}\n' for decorator in self.decorators) + header = particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION * 2)) + return f'{decorators}{header}\n{code}end' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [ + self.identifier, self.arguments, self.implicit_arguments, self.returns, self.code_block] + + +@dataclasses.dataclass +class CodeElementWith(CodeElement): + identifiers: List[AliasedIdentifier] + code_block: CodeBlock + + def format(self, allowed_line_length): + identifier_list_str = ', '.join(identifier.format() for identifier in self.identifiers) + inner_code = self.code_block.format(allowed_line_length=allowed_line_length - INDENTATION) + inner_code = indent(inner_code, INDENTATION) + return f'with {identifier_list_str}:\n{inner_code}end' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [*self.identifiers, self.code_block] + + +@dataclasses.dataclass +class CodeElementIf(CodeElement): + condition: BoolExpr + main_code_block: CodeBlock + else_code_block: Optional[CodeBlock] + label_neq: Optional[str] = None + label_end: Optional[str] = None + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + cond_particles = ['if ', *self.condition.get_particles()] + cond_particles[-1] = cond_particles[-1] + ':' + code = particles_in_lines( + particles=cond_particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION)) + main_code = self.main_code_block.format( + allowed_line_length=allowed_line_length - INDENTATION) + main_code = indent(main_code, INDENTATION) + code += f'\n{main_code}' + if self.else_code_block is not None: + code += f'else:' + else_code = self.else_code_block.format( + allowed_line_length=allowed_line_length - INDENTATION) + else_code = indent(else_code, INDENTATION) + code += f'\n{else_code}' + code += 'end' + return code + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.condition, 
self.main_code_block, self.else_code_block] + + +class Directive(AstNode): + @abstractmethod + def format(self): + pass + + +@dataclasses.dataclass +class BuiltinsDirective(Directive): + builtins: List[str] + location: Optional[Location] = LocationField + + def format(self): + return f'%builtins {" ".join(self.builtins)}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class CodeElementDirective(CodeElement): + directive: Directive + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return self.directive.format() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.directive] + + +@dataclasses.dataclass +class CodeElementImport(CodeElement): + path: ExprIdentifier + import_items: List[AliasedIdentifier] + notes: List[Notes] = NoteListField # type: ignore + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + for note in self.notes: + note.assert_no_comments() + + items = [item.format() for item in self.import_items] + prefix = f'from {self.path.format()} import ' + one_liner = prefix + ', '.join(items) + + if len(one_liner) <= allowed_line_length: + return one_liner + + particles = [f'{prefix}(', create_particle_sublist(items, ')')] + return particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=False)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.path, *self.import_items] + + +@dataclasses.dataclass +class CodeElementAllocLocals(CodeElement): + """ + Represents a statement of the form "alloc_locals". + """ + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return 'alloc_locals' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +def is_empty_line(code_element: CommentedCodeElement): + return isinstance(code_element.code_elm, CodeElementEmptyLine) and code_element.comment is None + + +def is_comment_line(code_element: CommentedCodeElement): + return isinstance(code_element.code_elm, CodeElementEmptyLine) and \ + code_element.comment is not None + + +def remove_redundant_empty_lines( + code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + """ + Returns a new list of code elements where redundant empty lines are removed. + Redundant empty lines are empty lines which are after: + 1. Empty lines. + 2. Labels. + or at the end of the list. + """ + new_code_elements = [] + skip_empty_lines = True + for code_elm in code_elements: + if is_empty_line(code_elm): + # Empty line. + if skip_empty_lines: + continue + skip_empty_lines = True + elif isinstance(code_elm.code_elm, CodeElementLabel): + skip_empty_lines = True + else: + skip_empty_lines = False + new_code_elements.append(code_elm) + + while len(new_code_elements) > 0 and is_empty_line(new_code_elements[-1]): + new_code_elements.pop() + + return new_code_elements + + +def add_empty_lines_before_labels( + code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + """ + Makes sure there is an empty line before labels. + The empty line is added before the comment lines preceding the label. 
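+    The list is scanned in reverse, so a comment block directly above a label stays attached to
+    the label and the empty line is inserted above those comments.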
+ """ + new_code_elements_reversed = [] + add_empty_line = False + for code_elm in code_elements[::-1]: + if add_empty_line: + if is_empty_line(code_elm): + add_empty_line = False + elif not is_comment_line(code_elm): + new_code_elements_reversed.append(CommentedCodeElement( + code_elm=CodeElementEmptyLine(), + comment=None)) + add_empty_line = False + + if isinstance(code_elm.code_elm, CodeElementLabel): + add_empty_line = True + + new_code_elements_reversed.append(code_elm) + + return new_code_elements_reversed[::-1] + + +def fix_comment_spaces(code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + """ + Comments should start with exactly one space after '#'. When a comment is spread across several + lines, the next lines may start with more than one space. + Returns a copy of code_elements, where comment prefix spaces are fixed. + """ + new_code_elements = [] + allow_additional_comment_spaces = False + for code_elm in code_elements: + # Additional spaces are never allowed in inline comments. + if not is_comment_line(code_elm): + allow_additional_comment_spaces = False + + new_code_elements.append(code_elm.fix_comment_spaces(allow_additional_comment_spaces)) + + if is_comment_line(code_elm): + # Next comment line may have additional spaces. + allow_additional_comment_spaces = True + return new_code_elements diff --git a/src/starkware/cairo/lang/compiler/ast/expr.py b/src/starkware/cairo/lang/compiler/ast/expr.py new file mode 100644 index 0000000..6afc5e2 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/expr.py @@ -0,0 +1,304 @@ +import dataclasses +import re +from abc import abstractmethod +from typing import List, Optional, Sequence + +from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, CastType +from starkware.cairo.lang.compiler.ast.formatting_utils import INDENTATION, LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import Notes, NotesField +from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.instruction import Register +from starkware.python.expression_string import ExpressionString + + +class Expression(AstNode): + location: Optional[Location] + + def format(self): + res = str(self.to_expr_str()) + # Indent all lines except for the first. + res = res.replace('\n', '\n' + ' ' * INDENTATION) + # Remove trailing spaces. + res = re.sub(r' +\n', '\n', res) + return res + + @abstractmethod + def to_expr_str(self) -> ExpressionString: + """ + Formats the Expression and returns an ExpressionString. This is useful for automatic + insertion of parentheses (where required). 
+ """ + + +@dataclasses.dataclass +class ExprConst(Expression): + val: int + location: Optional[Location] = LocationField + + def to_expr_str(self): + if self.val >= 0: + return ExpressionString.highest(str(self.val)) + return -ExpressionString.highest(str(-self.val)) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprPyConst(Expression): + code: str + location: Optional[Location] = LocationField + + @classmethod + def from_str(cls, src: str, location: Optional[Location] = None): + assert src.startswith('%[') + assert src.endswith('%]') + code = src[2:-2] + return cls(code, location) + + def to_expr_str(self): + return ExpressionString.highest(f'%[{self.code}%]') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprIdentifier(Expression): + name: str + location: Optional[Location] = LocationField + + def to_expr_str(self): + return ExpressionString.highest(self.name) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprAssignment(AstNode): + """ + A code element of the form [ident=]expr. The identifier is optional. + """ + identifier: Optional[ExprIdentifier] + expr: Expression + location: Optional[Location] = LocationField + + def format(self): + if self.identifier is None: + return self.expr.format() + return f'{self.identifier.format()}={self.expr.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier, self.expr] + + +@dataclasses.dataclass +class ArgList(AstNode): + """ + Represents a list of arguments (e.g., to a function call or a return statement). + For example: 'a=1, b=2'. + """ + args: List[ExprAssignment] + notes: List[Notes] + has_trailing_comma: bool + location: Optional[Location] = LocationField + + def assert_no_comments(self): + for note in self.notes: + note.assert_no_comments() + + def format(self): + if len(self.args) == 0: + assert len(self.notes) == 1 + return self.notes[0].format() + + code = '' + assert len(self.args) + 1 == len(self.notes) + for notes, arg in zip(self.notes[:-1], self.args): + if code != '': + code += ',' + if notes.empty: + code += ' ' + code += f'{notes.format()}{arg.format()}' + + # Add trailing comma at the end if necessary. + if self.has_trailing_comma: + code += ',' + code += self.notes[-1].format() + return code + + def get_children(self) -> Sequence[Optional[AstNode]]: + return self.args + + +@dataclasses.dataclass +class ExprReg(Expression): + reg: Register + location: Optional[Location] = LocationField + + def to_expr_str(self): + return ExpressionString.highest(self.reg.name.lower()) + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + +@dataclasses.dataclass +class ExprOperator(Expression): + a: Expression + op: str + b: Expression + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + self.notes.assert_no_comments() + a = self.a.to_expr_str() + b = self.b.to_expr_str() + if not self.notes.empty: + b = b.prepend('\n') + if self.op == '+': + return a + b + elif self.op == '-': + return a - b + elif self.op == '*': + return a * b + elif self.op == '/': + return a / b + else: + raise NotImplementedError(f"Unexpected operator '{self.op}'") + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class ExprAddressOf(Expression): + """ + Represents an expression of the form "&expr". 
+ """ + expr: Expression + location: Optional[Location] = LocationField + + def to_expr_str(self): + return self.expr.to_expr_str().address_of() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr] + + +@dataclasses.dataclass +class ExprNeg(Expression): + val: Expression + location: Optional[Location] = LocationField + + def to_expr_str(self): + return -self.val.to_expr_str() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.val] + + +@dataclasses.dataclass +class ExprParentheses(Expression): + val: Expression + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + return ExpressionString.highest(f'({self.notes.format()}{str(self.val.to_expr_str())})') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.val] + + +@dataclasses.dataclass +class ExprDeref(Expression): + """ + Represents an expression of the form "[expr]". + """ + addr: Expression + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + self.notes.assert_no_comments() + notes = '' if self.notes.empty else '\n' + return ExpressionString.highest(f'[{notes}{str(self.addr.to_expr_str())}]') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.addr] + + +@dataclasses.dataclass +class ExprDot(Expression): + """ + Represents an expression of the form "obj.member". + """ + expr: Expression + member: ExprIdentifier + location: Optional[Location] = LocationField + + def to_expr_str(self): + # If object is not an atom, add parentheses. + return ExpressionString.highest( + f'{self.expr.to_expr_str():HIGHEST}.{str(self.member.to_expr_str())}') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr, self.member] + + +@dataclasses.dataclass +class ExprCast(Expression): + """ + Represents a cast expression of the form "cast(expr, T)" (which transforms expr to type T). + """ + expr: Expression + dest_type: CairoType + # Cast expressions resulting from the Cairo code always have cast_type=CastType.EXPLICIT. + # 'cast_type' is only used when an ExprCast instance is created during compilation. + cast_type: CastType = CastType.EXPLICIT + notes: Notes = NotesField + location: Optional[Location] = LocationField + + def to_expr_str(self): + self.notes.assert_no_comments() + notes = '' if self.notes.empty else '\n' + return ExpressionString.highest( + f'cast({notes}{str(self.expr.to_expr_str())}, {self.dest_type.format()})') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr, self.dest_type] + + +@dataclasses.dataclass +class ExprTuple(Expression): + members: ArgList + location: Optional[Location] = LocationField + + def to_expr_str(self): + code = self.members.format() + return ExpressionString.highest(f'({code})') + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.members] + + +@dataclasses.dataclass +class ExprFutureLabel(Expression): + """ + Represents a future label whose current pc is not known yet. 
+ """ + identifier: ExprIdentifier + + def to_expr_str(self): + return self.identifier.to_expr_str() + + @property + def locaion(self): + return self.identifier.location + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier] diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py new file mode 100644 index 0000000..99b87bf --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py @@ -0,0 +1,150 @@ +""" +Contains utils that help with formatting of Cairo code. +""" + +import dataclasses +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import field +from typing import List + +import marshmallow + +from starkware.cairo.lang.compiler.error_handling import LocationError + +INDENTATION = 4 +LocationField = field(default=None, hash=False, compare=False, metadata=dict( + marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True))) +max_line_length_ctx_var: ContextVar[int] = ContextVar('max_line_length', default=100) + + +def get_max_line_length(): + return max_line_length_ctx_var.get() + + +@contextmanager +def set_max_line_length(line_length: bool): + """ + Context manager that sets max_line_length context variable. + """ + previous = get_max_line_length() + max_line_length_ctx_var.set(line_length) + yield + max_line_length_ctx_var.set(previous) + + +class FormattingError(LocationError): + pass + + +@dataclasses.dataclass +class ParticleFormattingConfig: + # The maximal line length. + allowed_line_length: int + # The indentation, starting from the second line. + line_indent: int + # The prefix of the first line. + first_line_prefix: str = '' + # At most one item per line. + one_per_line: bool = False + + +class ParticleLineBuilder: + """ + Builds particle lines, wrapping line lengths as needed. + """ + + def __init__(self, config: ParticleFormattingConfig): + self.lines: List[str] = [] + self.line = config.first_line_prefix + self.line_is_new = True + + self.config = config + + def newline(self): + """ + Opens a new line. + """ + if self.line_is_new: + return + self.lines.append(self.line) + self.line_is_new = True + self.line = ' ' * self.config.line_indent + + def add_to_line(self, string): + """ + Adds to current line, opening a new one if needed. + """ + if len(self.line) + len(string) > self.config.allowed_line_length and not self.line_is_new: + self.newline() + self.line += string + self.line_is_new = False + + def finalize(self): + """ + Finalizes the particle lines and returns the result. + """ + if self.line: + self.lines.append(self.line) + return '\n'.join(line.rstrip() for line in self.lines) + + +def create_particle_sublist(lst, end='', separator=', '): + if not lst: + # If the list is empty, return the single element 'end'. + return end + # Concatenate the 'separator' to all elements of the 'lst' and 'end' to the last one. + return [elm + separator for elm in lst[:-1]] + [lst[-1] + end] + + +def particles_in_lines(particles, config: ParticleFormattingConfig): + """ + Receives a list 'particles' that contains strings and particle sublists and generates lines + according to the following rules: + - The first line is not indented. All other lines start with 'line_indent' spaces. + - A line containing more than one particle can be no longer than 'allowed_line_length'. + - A sublist that cannot be fully concatenated to the current line opens a new line. 
+ + Example: + particles_in_lines( + ['func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist(['a', 'b', 'c'], '):')], + 12, 4) + returns '''\ + func f( + x, y, + z) -> ( + a, b, + c):\ + ''' + With a longer line length we will get the lists on the same line: + particles_in_lines( + ['func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist([], '):')], + 19, 4) + returns '''\ + func f( + x, y, z) -> ():\ + ''' + """ + + builder = ParticleLineBuilder(config=config) + + for particle in particles: + if isinstance(particle, str): + builder.add_to_line(particle) + + if isinstance(particle, list): + # If the entire sublist fits in a single line, add it. + if sum(map(len, particle), config.line_indent) < config.allowed_line_length: + builder.add_to_line(''.join(particle)) + continue + builder.newline() + for member in particle: + if config.one_per_line: + builder.newline() + builder.add_to_line(member) + + return builder.finalize() diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py new file mode 100644 index 0000000..d5a9c0d --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py @@ -0,0 +1,82 @@ +from starkware.cairo.lang.compiler.ast.formatting_utils import ( + ParticleFormattingConfig, create_particle_sublist, particles_in_lines) + + +def test_particles_in_lines(): + particles = [ + 'start ', + 'foo ', + 'bar ', + create_particle_sublist(['a', 'b', 'c', 'dddd', 'e', 'f'], '*'), + ' asdf', + ] + expected = """\ +start foo + bar + a, b, c, + dddd, e, + f* asdf\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=2), + ) == expected + + particles = [ + 'func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist(['a', 'b', 'c'], '):'), + ] + expected = """\ +func f( + x, y, + z) -> ( + a, b, + c):\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=4), + ) == expected + + # Same particles, using one_per_line=True. + expected = """\ +func f( + x, + y, + z) -> ( + a, + b, + c):\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=12, line_indent=4, one_per_line=True), + ) == expected + + # Same particles, using one_per_line=True, longer lines. 
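+    # Here each sublist fits in a single line, so one_per_line does not split its members.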
+ expected = """\ +func f( + x, y, z) -> ( + a, b, c):\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=19, line_indent=4, one_per_line=True), + ) == expected + + particles = [ + 'func f(', + create_particle_sublist(['x', 'y', 'z'], ') -> ('), + create_particle_sublist([], '):'), + ] + expected = """\ +func f( + x, y, z) -> ():\ +""" + assert particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=19, line_indent=4), + ) == expected diff --git a/src/starkware/cairo/lang/compiler/ast/instructions.py b/src/starkware/cairo/lang/compiler/ast/instructions.py new file mode 100644 index 0000000..8cf7b29 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/instructions.py @@ -0,0 +1,171 @@ +import dataclasses +from abc import abstractmethod +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.expr import Expression, ExprIdentifier +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location + + +class InstructionBody(AstNode): + """ + Represents the instruction without the flag ap++. + """ + + location: Optional[Location] + + @abstractmethod + def format(self) -> str: + """ + Returns a string representing the instruction. + """ + + +@dataclasses.dataclass +class AssertEqInstruction(InstructionBody): + """ + Represents the instruction "a = b" for two expressions a, b. + """ + + a: Expression + b: Expression + location: Optional[Location] = LocationField + + def format(self): + return f'{self.a.format()} = {self.b.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.a, self.b] + + +@dataclasses.dataclass +class JumpInstruction(InstructionBody): + """ + Represents the instruction "jmp rel/abs". + """ + + val: Expression + relative: bool + location: Optional[Location] = LocationField + + def format(self): + return f'jmp {"rel" if self.relative else "abs"} {self.val.format()}' + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.val] + + +@dataclasses.dataclass +class JumpToLabelInstruction(InstructionBody): + """ + Represents the instruction "jmp