diff --git a/Cargo.toml b/Cargo.toml index ef447671..7cb12d42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ default = ["pyo3/num-bigint", "pyo3/auto-initialize"] [dependencies] pyo3 = { version = "0.16.5" } -cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "4f36aaf46dea8cac158d0da5e80537388e048c01" } +cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "2ddf78e20cc25e660263a0c9c1b942780d95a0e6" } num-bigint = "0.4" lazy_static = "1.4.0" diff --git a/cairo_programs/ecdsa.cairo b/cairo_programs/ecdsa.cairo new file mode 100644 index 00000000..6d93528e --- /dev/null +++ b/cairo_programs/ecdsa.cairo @@ -0,0 +1,27 @@ +%builtins output pedersen ecdsa + +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.hash import hash2 +from starkware.cairo.common.signature import verify_ecdsa_signature + +func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, ecdsa_ptr : SignatureBuiltin*}(): + alloc_locals + + let your_eth_addr = 874739451078007766457464989774322083649278607533249481151382481072868806602 + let signature_r = 1839793652349538280924927302501143912227271479439798783640887258675143576352 + let signature_s = 1819432147005223164874083361865404672584671743718628757598322238853218813979 + let msg = 0000000000000000000000000000000000000000000000000000000000000002 + + verify_ecdsa_signature( + msg, + your_eth_addr, + signature_r, + signature_s, + ) + + + assert [output_ptr] = your_eth_addr + let output_ptr = output_ptr + 1 + + return () +end diff --git a/comparer_tracer.py b/comparer_tracer.py index ee3eefad..3f02e78e 100644 --- a/comparer_tracer.py +++ b/comparer_tracer.py @@ -8,7 +8,7 @@ def new_runner(program_name: str): if __name__ == "__main__": program_name = sys.argv[1] - if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write"]: + if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write", "dict_read", "dict_update"]: pass else: new_runner(program_name) diff --git a/hints_tests.py b/hints_tests.py index 626289c3..6a77380a 100644 --- a/hints_tests.py +++ b/hints_tests.py @@ -69,4 +69,5 @@ def test_program(program_name: str): test_program("blake2s_finalize") test_program("blake2s_felt") test_program("blake2s_integration_tests") + test_program("ecdsa") print("\nAll test have passed") diff --git a/src/cairo_runner.rs b/src/cairo_runner.rs index 0b4a6682..044315fe 100644 --- a/src/cairo_runner.rs +++ b/src/cairo_runner.rs @@ -541,6 +541,13 @@ impl PyCairoRunner { .collect::>() .to_object(py)) } + + /// Add (or replace if already present) a custom hash builtin. + /// Returns a Relocatable with the new hash builtin base. + pub fn add_additional_hash_builtin(&self) -> PyRelocatable { + let mut vm = (*self.pyvm.vm).borrow_mut(); + self.inner.add_additional_hash_builtin(&mut vm).into() + } } #[pyclass] @@ -1221,6 +1228,10 @@ mod test { segment_index: 6, offset: 0, })], + vec![RelocatableValue(PyRelocatable { + segment_index: 7, + offset: 0, + })], ]; Python::with_gil(|py| { @@ -1522,4 +1533,43 @@ mod test { ); }); } + + /// Test that add_additional_hash_builtin() returns successfully. + #[test] + fn add_additional_hash_builtin() { + Python::with_gil(|_| { + let program = fs::read_to_string("cairo_programs/fibonacci.json").unwrap(); + let runner = PyCairoRunner::new( + program, + Some("main".to_string()), + Some("small".to_string()), + false, + ) + .unwrap(); + + let expected_relocatable = PyRelocatable { + segment_index: 0, + offset: 0, + }; + let relocatable = runner.add_additional_hash_builtin(); + assert_eq!(expected_relocatable, relocatable); + + assert_eq!( + (*runner.pyvm.vm) + .borrow() + .get_builtin_runners() + .last() + .map(|(key, _)| key.as_str()), + Some("hash_builtin"), + ); + + let mut vm = (*runner.pyvm.vm).borrow_mut(); + // Check that the segment exists by writing to it. + vm.insert_value( + &Relocatable::from((0, 0)), + MaybeRelocatable::Int(bigint!(42)), + ) + .expect("memory insert failed"); + }); + } } diff --git a/src/ecdsa.rs b/src/ecdsa.rs new file mode 100644 index 00000000..75aa051d --- /dev/null +++ b/src/ecdsa.rs @@ -0,0 +1,57 @@ +use std::collections::HashMap; + +use cairo_rs::{ + types::relocatable::Relocatable, + vm::{errors::vm_errors::VirtualMachineError, runners::builtin_runner::SignatureBuiltinRunner}, +}; + +use num_bigint::BigInt; +use pyo3::prelude::*; + +use crate::relocatable::PyRelocatable; + +#[pyclass(name = "Signature")] +#[derive(Clone, Debug)] +pub struct PySignature { + signatures: HashMap, +} + +#[pymethods] +impl PySignature { + #[new] + pub fn new() -> Self { + Self { + signatures: HashMap::new(), + } + } + + pub fn add_signature(&mut self, address: PyRelocatable, pair: (BigInt, BigInt)) { + self.signatures.insert(address, pair); + } +} + +impl PySignature { + pub fn update_signature( + &self, + signature_builtin: &mut SignatureBuiltinRunner, + ) -> Result<(), VirtualMachineError> { + for (address, pair) in self.signatures.iter() { + signature_builtin + .add_signature(Relocatable::from(address), pair) + .map_err(VirtualMachineError::MemoryError)? + } + Ok(()) + } +} + +impl Default for PySignature { + fn default() -> Self { + Self::new() + } +} + +impl ToPyObject for PySignature { + fn to_object(&self, py: Python<'_>) -> PyObject { + self.clone().into_py(py) + } +} diff --git a/src/ids.rs b/src/ids.rs index 7c7490e1..f320b519 100644 --- a/src/ids.rs +++ b/src/ids.rs @@ -96,16 +96,14 @@ impl PyIds { struct_types: Rc::clone(&self.struct_types), } .into_py(py)); - } - - if self + } else if self .struct_types .contains_key(cairo_type.trim_end_matches('*')) { let addr = compute_addr_from_reference(hint_ref, &self.vm.borrow(), &self.ap_tracking)?; - let dereferenced_addr = self + let hint_value = self .vm .borrow() .get_relocatable(&addr) @@ -114,7 +112,7 @@ impl PyIds { return Ok(PyTypedId { vm: self.vm.clone(), - hint_value: dereferenced_addr, + hint_value, cairo_type: cairo_type.trim_end_matches('*').to_string(), struct_types: Rc::clone(&self.struct_types), } @@ -179,11 +177,10 @@ pub struct PyTypedId { impl PyTypedId { #[getter] fn __getattr__(&self, py: Python, name: &str) -> PyResult { - let struct_type = self.struct_types.get(&self.cairo_type).unwrap(); - if name == "address_" { return Ok(PyMaybeRelocatable::from(self.hint_value.clone()).to_object(py)); } + let struct_type = self.struct_types.get(&self.cairo_type).unwrap(); match struct_type.get(name) { Some(member) => { diff --git a/src/lib.rs b/src/lib.rs index 8ad0c606..a4d3e45b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod cairo_run; pub mod cairo_runner; mod dict_manager; +mod ecdsa; pub mod ids; mod memory; mod memory_segments; diff --git a/src/memory.rs b/src/memory.rs index f7383a95..c2e42fc8 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,16 +1,18 @@ use crate::{ relocatable::{PyMaybeRelocatable, PyRelocatable}, + utils::to_py_error, vm_core::PyVM, }; use cairo_rs::{ types::relocatable::{MaybeRelocatable, Relocatable}, vm::vm_core::VirtualMachine, }; +use num_bigint::BigInt; use pyo3::{ exceptions::{PyTypeError, PyValueError}, prelude::*, }; -use std::{cell::RefCell, rc::Rc}; +use std::{borrow::Cow, cell::RefCell, rc::Rc}; const MEMORY_GET_ERROR_MSG: &str = "Failed to get value from Cairo memory"; const MEMORY_SET_ERROR_MSG: &str = "Failed to set value to Cairo memory"; @@ -69,6 +71,18 @@ impl PyMemory { .collect::>() .to_object(py)) } + + /// Return a continuous section of memory as a vector of integers. + pub fn get_range_as_ints(&self, addr: PyRelocatable, size: usize) -> PyResult> { + Ok(self + .vm + .borrow() + .get_integer_range(&Relocatable::from(&addr), size) + .map_err(to_py_error)? + .into_iter() + .map(Cow::into_owned) + .collect()) + } } #[cfg(test)] @@ -282,4 +296,104 @@ assert memory[ap] == fp assert_eq!(range.unwrap_err(), expected_error); }); } + + // Test that get_range_as_ints() works as intended. + #[test] + fn get_range_as_ints() { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + let memory = PyMemory::new(&vm); + + let addr = { + let mut vm = vm.vm.borrow_mut(); + let addr = vm.add_memory_segment(); + + vm.load_data( + &MaybeRelocatable::from(&addr), + vec![ + bigint!(1).into(), + bigint!(2).into(), + bigint!(3).into(), + bigint!(4).into(), + ], + ) + .expect("memory insertion failed"); + + addr + }; + + assert_eq!( + memory + .get_range_as_ints(addr.into(), 4) + .expect("get_range_as_ints() failed"), + vec![bigint!(1), bigint!(2), bigint!(3), bigint!(4)], + ); + } + + // Test that get_range_as_ints() fails when not all values are integers. + #[test] + fn get_range_as_ints_mixed() { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + let memory = PyMemory::new(&vm); + + let addr = { + let mut vm = vm.vm.borrow_mut(); + let addr = vm.add_memory_segment(); + + vm.load_data( + &MaybeRelocatable::from(&addr), + vec![ + bigint!(1).into(), + bigint!(2).into(), + MaybeRelocatable::RelocatableValue((1, 2).into()), + bigint!(4).into(), + ], + ) + .expect("memory insertion failed"); + + addr + }; + + memory + .get_range_as_ints(addr.into(), 4) + .expect_err("get_range_as_ints() succeeded (should have failed)"); + } + + // Test that get_range_as_ints() fails when the requested range is larger than the available + // segments. + #[test] + fn get_range_as_ints_incomplete() { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + let memory = PyMemory::new(&vm); + + let addr = { + let mut vm = vm.vm.borrow_mut(); + let addr = vm.add_memory_segment(); + + vm.load_data( + &MaybeRelocatable::from(&addr), + vec![ + bigint!(1).into(), + bigint!(2).into(), + bigint!(3).into(), + bigint!(4).into(), + ], + ) + .expect("memory insertion failed"); + + addr + }; + + memory + .get_range_as_ints(addr.into(), 8) + .expect_err("get_range_as_ints() succeeded (should have failed)"); + } } diff --git a/src/relocatable.rs b/src/relocatable.rs index 3b47b953..33eb1ab6 100644 --- a/src/relocatable.rs +++ b/src/relocatable.rs @@ -18,7 +18,7 @@ pub enum PyMaybeRelocatable { } #[pyclass(name = "Relocatable")] -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct PyRelocatable { #[pyo3(get)] pub segment_index: isize, diff --git a/src/vm_core.rs b/src/vm_core.rs index 3ac1c918..f481dfef 100644 --- a/src/vm_core.rs +++ b/src/vm_core.rs @@ -1,3 +1,4 @@ +use crate::ecdsa::PySignature; use crate::ids::PyIds; use crate::pycell; use crate::scope_manager::{PyEnterScope, PyExitScope}; @@ -23,7 +24,7 @@ use std::any::Any; use std::collections::HashMap; use std::{cell::RefCell, rc::Rc}; -const GLOBAL_NAMES: [&str; 17] = [ +const GLOBAL_NAMES: [&str; 18] = [ "memory", "segments", "ap", @@ -33,6 +34,7 @@ const GLOBAL_NAMES: [&str; 17] = [ "vm_exit_scope", "to_felt_or_relocatable", "range_check_builtin", + "ecdsa_builtin", "PRIME", "__doc__", "__annotations__", @@ -76,8 +78,8 @@ impl PyVM { Python::with_gil(|py| -> Result<(), VirtualMachineError> { let memory = PyMemory::new(self); let segments = PySegmentManager::new(self, memory.clone()); - let ap = PyRelocatable::from(self.vm.borrow().get_ap()); - let fp = PyRelocatable::from(self.vm.borrow().get_fp()); + let ap = PyRelocatable::from((*self.vm).borrow().get_ap()); + let fp = PyRelocatable::from((*self.vm).borrow().get_fp()); let ids = PyIds::new( self, &hint_data.ids_data, @@ -88,8 +90,9 @@ impl PyVM { let enter_scope = pycell!(py, PyEnterScope::new()); let exit_scope = pycell!(py, PyExitScope::new()); let range_check_builtin = - PyRangeCheck::from(self.vm.borrow().get_range_check_builtin()); - let prime = self.vm.borrow().get_prime().clone(); + PyRangeCheck::from((*self.vm).borrow().get_range_check_builtin()); + let ecdsa_builtin = pycell!(py, PySignature::new()); + let prime = (*self.vm).borrow().get_prime().clone(); let to_felt_or_relocatable = ToFeltOrRelocatableFunc; // This line imports Python builtins. If not imported, this will run only with Python 3.10 @@ -126,6 +129,9 @@ impl PyVM { globals .set_item("range_check_builtin", range_check_builtin) .map_err(to_vm_error)?; + globals + .set_item("ecdsa_builtin", ecdsa_builtin) + .map_err(to_vm_error)?; globals.set_item("PRIME", prime).map_err(to_vm_error)?; globals .set_item( @@ -155,6 +161,11 @@ impl PyVM { py, ); + if self.vm.borrow_mut().get_signature_builtin().is_ok() { + ecdsa_builtin + .borrow() + .update_signature(self.vm.borrow_mut().get_signature_builtin()?)?; + } enter_scope.borrow().update_scopes(exec_scopes)?; exit_scope.borrow().update_scopes(exec_scopes) })?; @@ -171,7 +182,7 @@ impl PyVM { struct_types: Rc>>, constants: &HashMap, ) -> Result<(), VirtualMachineError> { - let pc_offset = self.vm.borrow().get_pc().offset; + let pc_offset = (*self.vm).borrow().get_pc().offset; if let Some(hint_list) = hint_data_dictionary.get(&pc_offset) { for hint_data in hint_list.iter() { diff --git a/tests/compare_vm_state.sh b/tests/compare_vm_state.sh index 1d21dbab..32578f0e 100755 --- a/tests/compare_vm_state.sh +++ b/tests/compare_vm_state.sh @@ -22,7 +22,7 @@ for file in $(ls $tests_path | grep .cairo$ | sed -E 's/\.cairo$//'); do path_file="$tests_path/$file" echo "$file" - if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ]); then + if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_update" ] || [ "$file" = "dict_read" ]); then if $trace; then if ! diff -q $path_file.trace $path_file.rs.trace; then echo "Traces for $file differ" diff --git a/tests/memory_comparator.py b/tests/memory_comparator.py index 23f3f9ea..69975268 100755 --- a/tests/memory_comparator.py +++ b/tests/memory_comparator.py @@ -8,7 +8,7 @@ def main(): cairo_mem = {} cairo_rs_mem = {} name = filename2.split("/")[-1] - if name in ["blake2s_felt", "blake2s_finalize", "blake2s_hello_world_hash", "blake2s_integration_tests", "dict_squash", "squash_dict", "dict_write"]: + if name in ["blake2s_felt", "blake2s_finalize", "blake2s_hello_world_hash", "blake2s_integration_tests", "dict_squash", "squash_dict", "dict_write", "dict_read", "dict_update"]: with open(filename1, 'rb') as f: cairo_raw = f.read() assert len(cairo_raw) % 40 == 0, f'{filename1}: malformed memory file from Cairo VM'