From 89262cae80e8f0a30c5a05e2d37e0a6b1d587869 Mon Sep 17 00:00:00 2001 From: Stoyan Kirov Date: Fri, 26 Apr 2024 00:43:03 +0300 Subject: [PATCH] Test contract_call; Replace OnceCell with OnceLock --- ampd/src/starknet/events/contract_call.rs | 189 ++++++++++++++++++++-- ampd/src/starknet/events/mod.rs | 10 +- 2 files changed, 185 insertions(+), 14 deletions(-) diff --git a/ampd/src/starknet/events/contract_call.rs b/ampd/src/starknet/events/contract_call.rs index d6786746b..fca60c71d 100644 --- a/ampd/src/starknet/events/contract_call.rs +++ b/ampd/src/starknet/events/contract_call.rs @@ -11,7 +11,7 @@ use crate::types::Hash; /// This is the event emitted by the gateway cairo contract on Starknet, /// when the call_contract method is called from a third party. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct ContractCallEvent { pub destination_address: String, pub destination_chain: String, @@ -63,21 +63,27 @@ impl TryFrom for ContractCallEvent { // This field, should not exceed 252 bits (a felt's length) let destination_chain = parse_cairo_short_string(&starknet_event.keys[1])?; - // source_address represents the original callContract sender and - // is the first field in data, by the order defined in the event. - let source_address = parse_cairo_short_string(&starknet_event.data[0])?; + // source_address represents the original caller of the `call_contract` gateway + // method. It is the first field in data, by the order defined in the + // event. + // + // TODO: Not sure if `064x` is the correct formatting. Maybe we should calculate + // the pedersen hash of the felt as described here, to get the actual address, + // although I'm not sure that we can do it as described here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/contract-address/ + let source_address = format!("0x{:064x}", starknet_event.data[0]); // destination_contract_address (ByteArray) is composed of FieldElements // from the second element to elemet X. let destination_address_chunks_count_felt = starknet_event.data[1]; - let destination_address_chunks_count = + let destination_address_chunks_count_u32 = u32::try_from(destination_address_chunks_count_felt)?; - let da_chunks_count_usize = usize::try_from(destination_address_chunks_count)?; + let da_chunks_count = usize::try_from(destination_address_chunks_count_u32)?; // It's + 3, because we need to offset the 0th element, pending_word and // pending_word_count, in addition to all chunks (da_chunks_count_usize) let da_elements_start_index: usize = 1; - let da_elements_end_index: usize = da_chunks_count_usize + 3; + let da_elements_end_index: usize = da_chunks_count + 3; let destination_address_byte_array: ByteArray = ByteArray::try_from( starknet_event.data[da_elements_start_index..=da_elements_end_index].to_vec(), )?; @@ -89,8 +95,20 @@ impl TryFrom for ContractCallEvent { let ph_chunk1_index: usize = da_elements_end_index + 1; let ph_chunk2_index: usize = ph_chunk1_index + 1; let mut payload_hash = [0; 32]; - let lsb: [u8; 32] = starknet_event.data[ph_chunk1_index].to_bytes_be(); - let msb: [u8; 32] = starknet_event.data[ph_chunk2_index].to_bytes_be(); + let lsb: [u8; 32] = starknet_event + .data + .get(ph_chunk1_index) + .ok_or(ContractCallError::InvalidEvent( + "payload_hash chunk 1 out of range".to_owned(), + ))? + .to_bytes_be(); + let msb: [u8; 32] = starknet_event + .data + .get(ph_chunk2_index) + .ok_or(ContractCallError::InvalidEvent( + "payload_hash chunk 2 out of range".to_owned(), + ))? + .to_bytes_be(); // most significat bits, go before least significant bits for u256 construction // check - https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_in_u256_values @@ -105,3 +123,156 @@ impl TryFrom for ContractCallEvent { }) } } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use ethers::types::H256; + use starknet_core::types::FieldElement; + use starknet_core::utils::starknet_keccak; + + use super::ContractCallEvent; + use crate::starknet::events::contract_call::ContractCallError; + + #[test] + fn destination_address_chunks_offset_out_of_range() { + let mut starknet_event = get_dummy_event(); + // longer chunk, which offsets the destination_address byte array out of range + starknet_event.data[1] = FieldElement::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000001", + ) + .unwrap(); + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!(event, ContractCallError::ByteArray(_))); + } + + #[test] + fn destination_address_chunks_count_too_long() { + let mut starknet_event = get_dummy_event(); + // too long for u32 + starknet_event.data[1] = FieldElement::MAX; + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!(event, ContractCallError::FeltOutOfRange(_))); + } + + #[test] + fn invalid_dest_chain() { + let mut starknet_event = get_dummy_event(); + // too long for Cairo long string too long + starknet_event.keys[1] = FieldElement::MAX; + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!(event, ContractCallError::Cairo(_))); + } + + #[test] + fn more_than_2_keys() { + // the payload is the word "hello" + let mut starknet_event = get_dummy_event(); + starknet_event + .keys + .push(starknet_keccak("additional_element".as_bytes())); + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!(event, ContractCallError::InvalidEvent(_))); + } + + #[test] + fn wrong_event_type() { + // the payload is the word "hello" + let mut starknet_event = get_dummy_event(); + starknet_event.keys[0] = starknet_keccak("NOTContractCall".as_bytes()); + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!(event, ContractCallError::InvalidEvent(_))); + } + + #[test] + fn valid_call_contract_event() { + // the payload is the word "hello" + let starknet_event = get_dummy_event(); + + let event = ContractCallEvent::try_from(starknet_event).unwrap(); + assert_eq!( + event, + ContractCallEvent { + destination_address: String::from("hello"), + destination_chain: String::from("destination_chain"), + source_address: String::from( + "0x00b3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca" + ), + payload_hash: H256::from_slice(&[ + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, + 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200 + ]) + } + ); + } + + fn get_dummy_event() -> starknet_core::types::Event { + starknet_core::types::Event { + // I think it's a pedersen hash, but we don't use it, so any value should do + from_address: starknet_keccak("some_from_address".as_bytes()), + keys: vec![ + starknet_keccak("ContractCall".as_bytes()), + FieldElement::from_str( + "0x00000000000000000000000000000064657374696e6174696f6e5f636861696e", + ) + .unwrap(), + ], + data: vec![ + FieldElement::from_str( + "0xb3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca", + ) + .unwrap(), + FieldElement::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap(), + FieldElement::from_str( + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + ) + .unwrap(), + FieldElement::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000005", + ) + .unwrap(), + FieldElement::from_str( + "0x0000000000000000000000000000000056d9517b9c948127319a09a7a36deac8", + ) + .unwrap(), + FieldElement::from_str( + "0x000000000000000000000000000000001c8aff950685c2ed4bc3174f3472287b", + ) + .unwrap(), + FieldElement::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000005", + ) + .unwrap(), + FieldElement::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000068", + ) + .unwrap(), + FieldElement::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000065", + ) + .unwrap(), + FieldElement::from_str( + "0x000000000000000000000000000000000000000000000000000000000000006c", + ) + .unwrap(), + FieldElement::from_str( + "0x000000000000000000000000000000000000000000000000000000000000006c", + ) + .unwrap(), + FieldElement::from_str( + "0x000000000000000000000000000000000000000000000000000000000000006f", + ) + .unwrap(), + ], + } + } +} diff --git a/ampd/src/starknet/events/mod.rs b/ampd/src/starknet/events/mod.rs index 28030a6e7..6b8ee7385 100644 --- a/ampd/src/starknet/events/mod.rs +++ b/ampd/src/starknet/events/mod.rs @@ -1,4 +1,4 @@ -use std::cell::OnceCell; +use std::sync::OnceLock; use starknet_core::types::FieldElement; use starknet_core::utils::starknet_keccak; @@ -6,8 +6,8 @@ use starknet_core::utils::starknet_keccak; pub mod contract_call; // Since a keccak hash over a string is a deterministic operation, -// we can use `OnceCall` to eliminate useless hash calculations. -const CALL_CONTRACT_FELT: OnceCell = OnceCell::new(); +// we can use `OnceLock` to eliminate useless hash calculations. +static CALL_CONTRACT_FELT: OnceLock = OnceLock::new(); /// All Axelar event types supported by starknet #[derive(Eq, PartialEq)] @@ -17,8 +17,8 @@ pub enum EventType { impl EventType { fn parse(event_type_felt: FieldElement) -> Option { - let binding = CALL_CONTRACT_FELT; - let contract_call_type = binding.get_or_init(|| starknet_keccak("ContractCall".as_bytes())); + let contract_call_type = + CALL_CONTRACT_FELT.get_or_init(|| starknet_keccak("ContractCall".as_bytes())); if event_type_felt == *contract_call_type { Some(EventType::ContractCall)