Skip to content

Commit

Permalink
Merge with main and fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Juan-M-V committed Nov 17, 2022
2 parents 0c6e440 + 56490ea commit e034b03
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ default = ["pyo3/num-bigint", "pyo3/auto-initialize"]

[dependencies]
pyo3 = { version = "0.16.5" }
cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "4f36aaf46dea8cac158d0da5e80537388e048c01" }
cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "8c47dda53e874545895b34d675be6254878a9e7b" }
num-bigint = "0.4"
lazy_static = "1.4.0"

Expand Down
27 changes: 27 additions & 0 deletions cairo_programs/ecdsa.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
%builtins output pedersen ecdsa

from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin
from starkware.cairo.common.hash import hash2
from starkware.cairo.common.signature import verify_ecdsa_signature

func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, ecdsa_ptr : SignatureBuiltin*}():
alloc_locals

let your_eth_addr = 874739451078007766457464989774322083649278607533249481151382481072868806602
let signature_r = 1839793652349538280924927302501143912227271479439798783640887258675143576352
let signature_s = 1819432147005223164874083361865404672584671743718628757598322238853218813979
let msg = 0000000000000000000000000000000000000000000000000000000000000002

verify_ecdsa_signature(
msg,
your_eth_addr,
signature_r,
signature_s,
)


assert [output_ptr] = your_eth_addr
let output_ptr = output_ptr + 1

return ()
end
2 changes: 1 addition & 1 deletion comparer_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def new_runner(program_name: str):

if __name__ == "__main__":
program_name = sys.argv[1]
if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write"]:
if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write", "dict_read", "dict_update"]:
pass
else:
new_runner(program_name)
Expand Down
3 changes: 2 additions & 1 deletion hints_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_program(program_name: str):
test_program("memset")
test_program("dict_new")
test_program("dict_read")
test_program("dict_write")
#test_program("dict_write")
test_program("dict_update")
test_program("default_dict_new")
# test_program("squash_dict") # ValueError: Custom Hint Error: ValueError: Failed to get ids value
Expand Down Expand Up @@ -69,4 +69,5 @@ def test_program(program_name: str):
test_program("blake2s_finalize")
test_program("blake2s_felt")
test_program("blake2s_integration_tests")
test_program("ecdsa")
print("\nAll test have passed")
57 changes: 57 additions & 0 deletions src/ecdsa.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::collections::HashMap;

use cairo_rs::{
types::relocatable::Relocatable,
vm::{errors::vm_errors::VirtualMachineError, runners::builtin_runner::SignatureBuiltinRunner},
};

use num_bigint::BigInt;
use pyo3::prelude::*;

use crate::relocatable::PyRelocatable;

#[pyclass(name = "Signature")]
#[derive(Clone, Debug)]
pub struct PySignature {
signatures: HashMap<PyRelocatable, (BigInt, BigInt)>,
}

#[pymethods]
impl PySignature {
#[new]
pub fn new() -> Self {
Self {
signatures: HashMap::new(),
}
}

pub fn add_signature(&mut self, address: PyRelocatable, pair: (BigInt, BigInt)) {
self.signatures.insert(address, pair);
}
}

impl PySignature {
pub fn update_signature(
&self,
signature_builtin: &mut SignatureBuiltinRunner,
) -> Result<(), VirtualMachineError> {
for (address, pair) in self.signatures.iter() {
signature_builtin
.add_signature(Relocatable::from(address), pair)
.map_err(VirtualMachineError::MemoryError)?
}
Ok(())
}
}

impl Default for PySignature {
fn default() -> Self {
Self::new()
}
}

impl ToPyObject for PySignature {
fn to_object(&self, py: Python<'_>) -> PyObject {
self.clone().into_py(py)
}
}
11 changes: 4 additions & 7 deletions src/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,14 @@ impl PyIds {
struct_types: Rc::clone(&self.struct_types),
}
.into_py(py));
}

if self
} else if self
.struct_types
.contains_key(cairo_type.trim_end_matches('*'))
{
let addr =
compute_addr_from_reference(hint_ref, &self.vm.borrow(), &self.ap_tracking)?;

let dereferenced_addr = self
let hint_value = self
.vm
.borrow()
.get_relocatable(&addr)
Expand All @@ -114,7 +112,7 @@ impl PyIds {

return Ok(PyTypedId {
vm: self.vm.clone(),
hint_value: dereferenced_addr,
hint_value,
cairo_type: cairo_type.trim_end_matches('*').to_string(),
struct_types: Rc::clone(&self.struct_types),
}
Expand Down Expand Up @@ -179,11 +177,10 @@ pub struct PyTypedId {
impl PyTypedId {
#[getter]
fn __getattr__(&self, py: Python, name: &str) -> PyResult<PyObject> {
let struct_type = self.struct_types.get(&self.cairo_type).unwrap();

if name == "address_" {
return Ok(PyMaybeRelocatable::from(self.hint_value.clone()).to_object(py));
}
let struct_type = self.struct_types.get(&self.cairo_type).unwrap();

match struct_type.get(name) {
Some(member) => {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod cairo_run;
pub mod cairo_runner;
mod dict_manager;
mod ecdsa;
pub mod ids;
mod memory;
mod memory_segments;
Expand Down
2 changes: 1 addition & 1 deletion src/relocatable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub enum PyMaybeRelocatable {
}

#[pyclass(name = "Relocatable")]
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PyRelocatable {
#[pyo3(get)]
pub segment_index: isize,
Expand Down
23 changes: 17 additions & 6 deletions src/vm_core.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::ecdsa::PySignature;
use crate::ids::PyIds;
use crate::pycell;
use crate::scope_manager::{PyEnterScope, PyExitScope};
Expand All @@ -23,7 +24,7 @@ use std::any::Any;
use std::collections::HashMap;
use std::{cell::RefCell, rc::Rc};

const GLOBAL_NAMES: [&str; 17] = [
const GLOBAL_NAMES: [&str; 18] = [
"memory",
"segments",
"ap",
Expand All @@ -33,6 +34,7 @@ const GLOBAL_NAMES: [&str; 17] = [
"vm_exit_scope",
"to_felt_or_relocatable",
"range_check_builtin",
"ecdsa_builtin",
"PRIME",
"__doc__",
"__annotations__",
Expand Down Expand Up @@ -76,8 +78,8 @@ impl PyVM {
Python::with_gil(|py| -> Result<(), VirtualMachineError> {
let memory = PyMemory::new(self);
let segments = PySegmentManager::new(self, memory.clone());
let ap = PyRelocatable::from(self.vm.borrow().get_ap());
let fp = PyRelocatable::from(self.vm.borrow().get_fp());
let ap = PyRelocatable::from((*self.vm).borrow().get_ap());
let fp = PyRelocatable::from((*self.vm).borrow().get_fp());
let ids = PyIds::new(
self,
&hint_data.ids_data,
Expand All @@ -88,8 +90,9 @@ impl PyVM {
let enter_scope = pycell!(py, PyEnterScope::new());
let exit_scope = pycell!(py, PyExitScope::new());
let range_check_builtin =
PyRangeCheck::from(self.vm.borrow().get_range_check_builtin());
let prime = self.vm.borrow().get_prime().clone();
PyRangeCheck::from((*self.vm).borrow().get_range_check_builtin());
let ecdsa_builtin = pycell!(py, PySignature::new());
let prime = (*self.vm).borrow().get_prime().clone();
let to_felt_or_relocatable = ToFeltOrRelocatableFunc;

// This line imports Python builtins. If not imported, this will run only with Python 3.10
Expand Down Expand Up @@ -126,6 +129,9 @@ impl PyVM {
globals
.set_item("range_check_builtin", range_check_builtin)
.map_err(to_vm_error)?;
globals
.set_item("ecdsa_builtin", ecdsa_builtin)
.map_err(to_vm_error)?;
globals.set_item("PRIME", prime).map_err(to_vm_error)?;
globals
.set_item(
Expand Down Expand Up @@ -155,6 +161,11 @@ impl PyVM {
py,
);

if self.vm.borrow_mut().get_signature_builtin().is_ok() {
ecdsa_builtin
.borrow()
.update_signature(self.vm.borrow_mut().get_signature_builtin()?)?;
}
enter_scope.borrow().update_scopes(exec_scopes)?;
exit_scope.borrow().update_scopes(exec_scopes)
})?;
Expand All @@ -171,7 +182,7 @@ impl PyVM {
struct_types: Rc<HashMap<String, HashMap<String, Member>>>,
constants: &HashMap<String, BigInt>,
) -> Result<(), VirtualMachineError> {
let pc_offset = self.vm.borrow().get_pc().offset;
let pc_offset = (*self.vm).borrow().get_pc().offset;

if let Some(hint_list) = hint_data_dictionary.get(&pc_offset) {
for hint_data in hint_list.iter() {
Expand Down
2 changes: 1 addition & 1 deletion tests/compare_vm_state.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ for file in $(ls $tests_path | grep .cairo$ | sed -E 's/\.cairo$//'); do
path_file="$tests_path/$file"

echo "$file"
if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ]); then
if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_update" ] || [ "$file" = "dict_read" ]); then
if $trace; then
if ! diff -q $path_file.trace $path_file.rs.trace; then
echo "Traces for $file differ"
Expand Down
2 changes: 1 addition & 1 deletion tests/memory_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def main():
cairo_mem = {}
cairo_rs_mem = {}
name = filename2.split("/")[-1]
if name in ["blake2s_felt", "blake2s_finalize", "blake2s_hello_world_hash", "blake2s_integration_tests", "dict_squash", "squash_dict", "dict_write"]:
if name in ["blake2s_felt", "blake2s_finalize", "blake2s_hello_world_hash", "blake2s_integration_tests", "dict_squash", "squash_dict", "dict_write", "dict_read", "dict_update"]:
with open(filename1, 'rb') as f:
cairo_raw = f.read()
assert len(cairo_raw) % 40 == 0, f'{filename1}: malformed memory file from Cairo VM'
Expand Down

0 comments on commit e034b03

Please sign in to comment.