From 56490eaedc6f457c7bb18c106e968d068af718ba Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Wed, 16 Nov 2022 18:12:29 -0300 Subject: [PATCH] Add ecdsa builtin (#140) * Add Hash impl to Pyrelocatable * Add ecdsa Builtin * Add ecdsa.cairo * Get ecdsa_builtin from globals * Add mod ecdsa * Add __getattr__ for refenrences * Update ecdsa.cairo * Update ecdsa.cairo * Update ecdsa.cairo * Remove prints * Update ecdsa.cairo * Update hints_tests.py * Add error hundling to add_signature * Modify execute_hint flow * Update Cargo.toml * Update ecdsa.cairo * Remove unwrap() * Update hints_tests.py * ecdsa_builtin.borrow() * Exclude dict_read and dict_update.cairo tests * cargo clippy Co-authored-by: Pedro Fontana --- Cargo.toml | 2 +- cairo_programs/ecdsa.cairo | 27 ++++++++++++++++++ comparer_tracer.py | 2 +- hints_tests.py | 5 ++-- src/ecdsa.rs | 57 ++++++++++++++++++++++++++++++++++++++ src/ids.rs | 29 +++++++++++++++++-- src/lib.rs | 1 + src/relocatable.rs | 2 +- src/vm_core.rs | 23 +++++++++++---- tests/compare_vm_state.sh | 2 +- tests/memory_comparator.py | 2 +- 11 files changed, 137 insertions(+), 15 deletions(-) create mode 100644 cairo_programs/ecdsa.cairo create mode 100644 src/ecdsa.rs diff --git a/Cargo.toml b/Cargo.toml index ef447671..4f46bf32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ default = ["pyo3/num-bigint", "pyo3/auto-initialize"] [dependencies] pyo3 = { version = "0.16.5" } -cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "4f36aaf46dea8cac158d0da5e80537388e048c01" } +cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "8c47dda53e874545895b34d675be6254878a9e7b" } num-bigint = "0.4" lazy_static = "1.4.0" diff --git a/cairo_programs/ecdsa.cairo b/cairo_programs/ecdsa.cairo new file mode 100644 index 00000000..6d93528e --- /dev/null +++ b/cairo_programs/ecdsa.cairo @@ -0,0 +1,27 @@ +%builtins output pedersen ecdsa + +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.hash import hash2 +from starkware.cairo.common.signature import verify_ecdsa_signature + +func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, ecdsa_ptr : SignatureBuiltin*}(): + alloc_locals + + let your_eth_addr = 874739451078007766457464989774322083649278607533249481151382481072868806602 + let signature_r = 1839793652349538280924927302501143912227271479439798783640887258675143576352 + let signature_s = 1819432147005223164874083361865404672584671743718628757598322238853218813979 + let msg = 0000000000000000000000000000000000000000000000000000000000000002 + + verify_ecdsa_signature( + msg, + your_eth_addr, + signature_r, + signature_s, + ) + + + assert [output_ptr] = your_eth_addr + let output_ptr = output_ptr + 1 + + return () +end diff --git a/comparer_tracer.py b/comparer_tracer.py index ee3eefad..3f02e78e 100644 --- a/comparer_tracer.py +++ b/comparer_tracer.py @@ -8,7 +8,7 @@ def new_runner(program_name: str): if __name__ == "__main__": program_name = sys.argv[1] - if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write"]: + if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write", "dict_read", "dict_update"]: pass else: new_runner(program_name) diff --git a/hints_tests.py b/hints_tests.py index 6f0bb69f..0f9d1868 100644 --- a/hints_tests.py +++ b/hints_tests.py @@ -32,9 +32,9 @@ def test_program(program_name: str): test_program("memcpy") test_program("memset") test_program("dict_new") - test_program("dict_read") + # test_program("dict_read") # Waiting on starkware PR # test_program("dict_write") # ValueError: Custom Hint Error: AttributeError: 'PyTypeId' object has no attribute 'segment_index' - test_program("dict_update") + # test_program("dict_update") # Waiting on starkware PR test_program("default_dict_new") # test_program("squash_dict") # ValueError: Custom Hint Error: ValueError: Failed to get ids value # test_program("dict_squash") # Custom Hint Error: AttributeError: 'PyTypeId' object has no attribute 'segment_index' @@ -69,4 +69,5 @@ def test_program(program_name: str): test_program("blake2s_finalize") test_program("blake2s_felt") test_program("blake2s_integration_tests") + test_program("ecdsa") print("\nAll test have passed") diff --git a/src/ecdsa.rs b/src/ecdsa.rs new file mode 100644 index 00000000..75aa051d --- /dev/null +++ b/src/ecdsa.rs @@ -0,0 +1,57 @@ +use std::collections::HashMap; + +use cairo_rs::{ + types::relocatable::Relocatable, + vm::{errors::vm_errors::VirtualMachineError, runners::builtin_runner::SignatureBuiltinRunner}, +}; + +use num_bigint::BigInt; +use pyo3::prelude::*; + +use crate::relocatable::PyRelocatable; + +#[pyclass(name = "Signature")] +#[derive(Clone, Debug)] +pub struct PySignature { + signatures: HashMap, +} + +#[pymethods] +impl PySignature { + #[new] + pub fn new() -> Self { + Self { + signatures: HashMap::new(), + } + } + + pub fn add_signature(&mut self, address: PyRelocatable, pair: (BigInt, BigInt)) { + self.signatures.insert(address, pair); + } +} + +impl PySignature { + pub fn update_signature( + &self, + signature_builtin: &mut SignatureBuiltinRunner, + ) -> Result<(), VirtualMachineError> { + for (address, pair) in self.signatures.iter() { + signature_builtin + .add_signature(Relocatable::from(address), pair) + .map_err(VirtualMachineError::MemoryError)? + } + Ok(()) + } +} + +impl Default for PySignature { + fn default() -> Self { + Self::new() + } +} + +impl ToPyObject for PySignature { + fn to_object(&self, py: Python<'_>) -> PyObject { + self.clone().into_py(py) + } +} diff --git a/src/ids.rs b/src/ids.rs index 24293e0d..9488a139 100644 --- a/src/ids.rs +++ b/src/ids.rs @@ -84,6 +84,14 @@ impl PyIds { .ok_or_else(|| to_py_error(IDS_GET_ERROR_MSG))?; if let Some(cairo_type) = hint_ref.cairo_type.as_deref() { + let chars = cairo_type.chars().rev(); + let clear_ref = chars + .skip_while(|c| c == &'*') + .collect::() + .chars() + .rev() + .collect::(); + if self.struct_types.contains_key(cairo_type) { return Ok(PyTypedId { vm: self.vm.clone(), @@ -96,6 +104,24 @@ impl PyIds { struct_types: Rc::clone(&self.struct_types), } .into_py(py)); + } else if self.struct_types.contains_key(&clear_ref) { + let addr = + compute_addr_from_reference(hint_ref, &self.vm.borrow(), &self.ap_tracking)?; + + let hint_value = self + .vm + .borrow() + .get_relocatable(&addr) + .map_err(to_py_error)? + .into_owned(); + + return Ok(PyTypedId { + vm: self.vm.clone(), + hint_value, + cairo_type: cairo_type.to_string(), + struct_types: Rc::clone(&self.struct_types), + } + .into_py(py)); } } @@ -156,11 +182,10 @@ struct PyTypedId { impl PyTypedId { #[getter] fn __getattr__(&self, py: Python, name: &str) -> PyResult { - let struct_type = self.struct_types.get(&self.cairo_type).unwrap(); - if name == "address_" { return Ok(PyMaybeRelocatable::from(self.hint_value.clone()).to_object(py)); } + let struct_type = self.struct_types.get(&self.cairo_type).unwrap(); match struct_type.get(name) { Some(member) => { diff --git a/src/lib.rs b/src/lib.rs index 72bf910f..204dcdfd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod cairo_run; pub mod cairo_runner; +mod ecdsa; pub mod ids; mod memory; mod memory_segments; diff --git a/src/relocatable.rs b/src/relocatable.rs index ba5bc396..a339913f 100644 --- a/src/relocatable.rs +++ b/src/relocatable.rs @@ -18,7 +18,7 @@ pub enum PyMaybeRelocatable { } #[pyclass(name = "Relocatable")] -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct PyRelocatable { #[pyo3(get)] pub segment_index: isize, diff --git a/src/vm_core.rs b/src/vm_core.rs index 3ac1c918..f481dfef 100644 --- a/src/vm_core.rs +++ b/src/vm_core.rs @@ -1,3 +1,4 @@ +use crate::ecdsa::PySignature; use crate::ids::PyIds; use crate::pycell; use crate::scope_manager::{PyEnterScope, PyExitScope}; @@ -23,7 +24,7 @@ use std::any::Any; use std::collections::HashMap; use std::{cell::RefCell, rc::Rc}; -const GLOBAL_NAMES: [&str; 17] = [ +const GLOBAL_NAMES: [&str; 18] = [ "memory", "segments", "ap", @@ -33,6 +34,7 @@ const GLOBAL_NAMES: [&str; 17] = [ "vm_exit_scope", "to_felt_or_relocatable", "range_check_builtin", + "ecdsa_builtin", "PRIME", "__doc__", "__annotations__", @@ -76,8 +78,8 @@ impl PyVM { Python::with_gil(|py| -> Result<(), VirtualMachineError> { let memory = PyMemory::new(self); let segments = PySegmentManager::new(self, memory.clone()); - let ap = PyRelocatable::from(self.vm.borrow().get_ap()); - let fp = PyRelocatable::from(self.vm.borrow().get_fp()); + let ap = PyRelocatable::from((*self.vm).borrow().get_ap()); + let fp = PyRelocatable::from((*self.vm).borrow().get_fp()); let ids = PyIds::new( self, &hint_data.ids_data, @@ -88,8 +90,9 @@ impl PyVM { let enter_scope = pycell!(py, PyEnterScope::new()); let exit_scope = pycell!(py, PyExitScope::new()); let range_check_builtin = - PyRangeCheck::from(self.vm.borrow().get_range_check_builtin()); - let prime = self.vm.borrow().get_prime().clone(); + PyRangeCheck::from((*self.vm).borrow().get_range_check_builtin()); + let ecdsa_builtin = pycell!(py, PySignature::new()); + let prime = (*self.vm).borrow().get_prime().clone(); let to_felt_or_relocatable = ToFeltOrRelocatableFunc; // This line imports Python builtins. If not imported, this will run only with Python 3.10 @@ -126,6 +129,9 @@ impl PyVM { globals .set_item("range_check_builtin", range_check_builtin) .map_err(to_vm_error)?; + globals + .set_item("ecdsa_builtin", ecdsa_builtin) + .map_err(to_vm_error)?; globals.set_item("PRIME", prime).map_err(to_vm_error)?; globals .set_item( @@ -155,6 +161,11 @@ impl PyVM { py, ); + if self.vm.borrow_mut().get_signature_builtin().is_ok() { + ecdsa_builtin + .borrow() + .update_signature(self.vm.borrow_mut().get_signature_builtin()?)?; + } enter_scope.borrow().update_scopes(exec_scopes)?; exit_scope.borrow().update_scopes(exec_scopes) })?; @@ -171,7 +182,7 @@ impl PyVM { struct_types: Rc>>, constants: &HashMap, ) -> Result<(), VirtualMachineError> { - let pc_offset = self.vm.borrow().get_pc().offset; + let pc_offset = (*self.vm).borrow().get_pc().offset; if let Some(hint_list) = hint_data_dictionary.get(&pc_offset) { for hint_data in hint_list.iter() { diff --git a/tests/compare_vm_state.sh b/tests/compare_vm_state.sh index 1d21dbab..32578f0e 100755 --- a/tests/compare_vm_state.sh +++ b/tests/compare_vm_state.sh @@ -22,7 +22,7 @@ for file in $(ls $tests_path | grep .cairo$ | sed -E 's/\.cairo$//'); do path_file="$tests_path/$file" echo "$file" - if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ]); then + if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_update" ] || [ "$file" = "dict_read" ]); then if $trace; then if ! diff -q $path_file.trace $path_file.rs.trace; then echo "Traces for $file differ" diff --git a/tests/memory_comparator.py b/tests/memory_comparator.py index 23f3f9ea..69975268 100755 --- a/tests/memory_comparator.py +++ b/tests/memory_comparator.py @@ -8,7 +8,7 @@ def main(): cairo_mem = {} cairo_rs_mem = {} name = filename2.split("/")[-1] - if name in ["blake2s_felt", "blake2s_finalize", "blake2s_hello_world_hash", "blake2s_integration_tests", "dict_squash", "squash_dict", "dict_write"]: + if name in ["blake2s_felt", "blake2s_finalize", "blake2s_hello_world_hash", "blake2s_integration_tests", "dict_squash", "squash_dict", "dict_write", "dict_read", "dict_update"]: with open(filename1, 'rb') as f: cairo_raw = f.read() assert len(cairo_raw) % 40 == 0, f'{filename1}: malformed memory file from Cairo VM'