diff --git a/crates/cli/src/cli.rs b/crates/cli/src/cli.rs index 92633007..b0466a45 100644 --- a/crates/cli/src/cli.rs +++ b/crates/cli/src/cli.rs @@ -243,6 +243,14 @@ pub struct Cli { /// Disable auto and interval mining, and mine on demand instead. #[arg(long, visible_alias = "no-mine", conflicts_with = "block_time")] pub no_mining: bool, + + /// The cors `allow_origin` header + #[arg(long, default_value = "*", help_heading = "Server options")] + pub allow_origin: String, + + /// Disable CORS. + #[arg(long, default_missing_value = "true", num_args(0..=1), conflicts_with = "allow_origin", help_heading = "Server options")] + pub no_cors: Option, } #[derive(Debug, Subcommand, Clone)] @@ -381,7 +389,9 @@ impl Cli { None }) .with_block_time(self.block_time) - .with_no_mining(self.no_mining); + .with_no_mining(self.no_mining) + .with_allow_origin(self.allow_origin) + .with_no_cors(self.no_cors); if self.emulate_evm && self.dev_system_contracts != Some(SystemContractsOptions::Local) { return Err(eyre::eyre!( diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index b0c88c97..95402c84 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -29,6 +29,7 @@ use futures::{ FutureExt, }; use jsonrpc_core::MetaIoHandler; +use jsonrpc_http_server::DomainsValidation; use logging_middleware::LoggingMiddleware; use std::fs::File; use std::{env, net::SocketAddr, str::FromStr}; @@ -50,6 +51,8 @@ async fn build_json_http< log_level_filter: LevelFilter, node: InMemoryNode, enable_health_api: bool, + cors_allow_origin: String, + disable_cors: bool, ) -> tokio::task::JoinHandle<()> { let (sender, recv) = oneshot::channel::<()>(); @@ -76,9 +79,15 @@ async fn build_json_http< .build() .unwrap(); + let allow_origin = if disable_cors { + "null" + } else { + &cors_allow_origin + }; let mut builder = jsonrpc_http_server::ServerBuilder::new(io_handler) .threads(1) - .event_loop_executor(runtime.handle().clone()); + .event_loop_executor(runtime.handle().clone()) + .cors(DomainsValidation::AllowOnly(vec![allow_origin.into()])); if enable_health_api { builder = builder.health_api(("/health", "web3_clientVersion")); @@ -324,6 +333,8 @@ async fn main() -> anyhow::Result<()> { log_level_filter, node.clone(), config.health_check_endpoint, + config.allow_origin.clone(), + config.no_cors, ) })) .await; diff --git a/crates/config/src/config.rs b/crates/config/src/config.rs index d33592e1..e0291d43 100644 --- a/crates/config/src/config.rs +++ b/crates/config/src/config.rs @@ -111,6 +111,10 @@ pub struct TestNodeConfig { pub max_transactions: usize, /// Disable automatic sealing mode and use `BlockSealer::Noop` instead pub no_mining: bool, + /// The cors `allow_origin` header + pub allow_origin: String, + /// Disable CORS if true + pub no_cors: bool, } impl Default for TestNodeConfig { @@ -171,6 +175,10 @@ impl Default for TestNodeConfig { no_mining: false, max_transactions: 1000, + + // Server configuration + allow_origin: "*".to_string(), + no_cors: false, } } } @@ -869,4 +877,20 @@ impl TestNodeConfig { self.no_mining = no_mining; self } + + // Set allow_origin CORS header + #[must_use] + pub fn with_allow_origin(mut self, allow_origin: String) -> Self { + self.allow_origin = allow_origin; + self + } + + // Enable or disable CORS + #[must_use] + pub fn with_no_cors(mut self, no_cors: Option) -> Self { + if let Some(no_cors) = no_cors { + self.no_cors = no_cors; + } + self + } } diff --git a/e2e-tests-rust/src/lib.rs b/e2e-tests-rust/src/lib.rs index 59a34eb4..599f42f9 100644 --- a/e2e-tests-rust/src/lib.rs +++ b/e2e-tests-rust/src/lib.rs @@ -5,4 +5,4 @@ mod provider; mod utils; pub use ext::{ReceiptExt, ZksyncWalletProviderExt}; -pub use provider::{init_testing_provider, AnvilZKsyncApi, TestingProvider, DEFAULT_TX_VALUE}; +pub use provider::{init_testing_provider, init_testing_provider_with_http_headers, AnvilZKsyncApi, TestingProvider, DEFAULT_TX_VALUE}; diff --git a/e2e-tests-rust/src/provider/mod.rs b/e2e-tests-rust/src/provider/mod.rs index cf6e0b6d..1dbed218 100644 --- a/e2e-tests-rust/src/provider/mod.rs +++ b/e2e-tests-rust/src/provider/mod.rs @@ -2,4 +2,4 @@ mod anvil_zksync; mod testing; pub use anvil_zksync::AnvilZKsyncApi; -pub use testing::{init_testing_provider, TestingProvider, DEFAULT_TX_VALUE}; +pub use testing::{init_testing_provider, init_testing_provider_with_http_headers, TestingProvider, DEFAULT_TX_VALUE}; diff --git a/e2e-tests-rust/src/provider/testing.rs b/e2e-tests-rust/src/provider/testing.rs index 48e713b6..897d60e4 100644 --- a/e2e-tests-rust/src/provider/testing.rs +++ b/e2e-tests-rust/src/provider/testing.rs @@ -1,21 +1,32 @@ -use crate::utils::LockedPort; +use crate::utils::{LockedPort,get_node_binary_path}; use crate::ReceiptExt; use alloy::network::primitives::{BlockTransactionsKind, HeaderResponse as _}; use alloy::network::{Network, ReceiptResponse as _, TransactionBuilder}; use alloy::primitives::{Address, U256}; +use alloy::signers::local::LocalSigner; use alloy::providers::{ PendingTransaction, PendingTransactionBuilder, PendingTransactionError, Provider, RootProvider, SendableTx, WalletProvider, }; -use alloy::rpc::types::{Block, TransactionRequest}; -use alloy::transports::http::{reqwest, Http}; +use alloy::rpc::{ + types::{Block, TransactionRequest}, + client::RpcClient, +}; +use alloy::transports::http::{ + reqwest, + reqwest::{ + header::HeaderMap, + Client, + }, + Http +}; use alloy::transports::{RpcError, Transport, TransportErrorKind, TransportResult}; use alloy_zksync::network::header_response::HeaderResponse; use alloy_zksync::network::receipt_response::ReceiptResponse; use alloy_zksync::network::transaction_response::TransactionResponse; use alloy_zksync::network::Zksync; -use alloy_zksync::node_bindings::EraTestNode; -use alloy_zksync::provider::{zksync_provider, ProviderBuilderExt}; +use alloy_zksync::node_bindings::{EraTestNode,EraTestNodeError::NoKeysAvailable}; +use alloy_zksync::provider::{zksync_provider, ProviderBuilderExt, layers::era_test_node::EraTestNodeLayer}; use alloy_zksync::wallet::ZksyncWallet; use anyhow::Context as _; use itertools::Itertools; @@ -71,10 +82,7 @@ pub async fn init_testing_provider( .with_recommended_fillers() .on_era_test_node_with_wallet_and_config(|node| { f(node - .path( - std::env::var("ANVIL_ZKSYNC_BINARY_PATH") - .unwrap_or("../target/release/anvil-zksync".to_string()), - ) + .path(get_node_binary_path()) .port(locked_port.port)) }); @@ -93,6 +101,66 @@ pub async fn init_testing_provider( }) } +// Init testing provider which sends specified HTTP headers e.g. for authentication +// Outside of `TestingProvider` to avoid specifying `P` +pub async fn init_testing_provider_with_http_headers( + headers: HeaderMap, + f: impl FnOnce(EraTestNode) -> EraTestNode, +) -> anyhow::Result< + TestingProvider>, Http>, +> { + use alloy::signers::Signer; + + let locked_port = LockedPort::acquire_unused().await?; + let node_layer = EraTestNodeLayer::from( + f( + EraTestNode::new() + .path(get_node_binary_path()) + .port(locked_port.port) + ) + ); + + let client_with_headers = Client::builder().default_headers(headers).build()?; + let rpc_url = node_layer.endpoint_url(); + let http = Http::with_client(client_with_headers, rpc_url); + let rpc_client = RpcClient::new(http, true); + + let default_keys = node_layer.instance().keys().to_vec(); + let (default_key, remaining_keys) = default_keys + .split_first() + .ok_or(NoKeysAvailable)?; + + let default_signer = LocalSigner::from(default_key.clone()) + .with_chain_id(Some(node_layer.instance().chain_id())); + let mut wallet = ZksyncWallet::from(default_signer); + + for key in remaining_keys { + let signer = LocalSigner::from(key.clone()); + wallet.register_signer(signer) + } + + let provider = zksync_provider() + .with_recommended_fillers() + .wallet(wallet) + .layer(node_layer) + .on_client(rpc_client); + + // Grab default rich accounts right after init. Note that subsequent calls to this method + // might return different value as wallet's signers are dynamic and can be changed by the user. + let rich_accounts = provider.signer_addresses().collect::>(); + // Wait for anvil-zksync to get up and be able to respond + // Ignore error response (should not fail here if provider is used with intentionally wrong origin for testing purposes) + let _ = provider.get_chain_id().await; + // Explicitly unlock the port to showcase why we waited above + drop(locked_port); + + Ok(TestingProvider { + inner: provider, + rich_accounts, + _pd: Default::default(), + }) +} + impl TestingProvider where P: FullZksyncProvider, @@ -383,13 +451,7 @@ where self } - /// Builder-pattern method for setting the chain id. - pub fn with_chain_id(mut self, id: u64) -> Self { - self.inner = self.inner.with_chain_id(id); - self - } - - /// Builder-pattern method for setting the recipient. + /// Builder-pattern method for setting the receiver. pub fn with_to(mut self, to: Address) -> Self { self.inner = self.inner.with_to(to); self @@ -401,6 +463,12 @@ where self } + /// Builder-pattern method for setting the chain id. + pub fn with_chain_id(mut self, id: u64) -> Self { + self.inner = self.inner.with_chain_id(id); + self + } + /// Submits transaction to the node. /// /// This does not wait for the transaction to be confirmed, but returns a [`PendingTransactionFinalizable`] diff --git a/e2e-tests-rust/src/utils.rs b/e2e-tests-rust/src/utils.rs index d7c641f7..2a62bd53 100644 --- a/e2e-tests-rust/src/utils.rs +++ b/e2e-tests-rust/src/utils.rs @@ -71,3 +71,8 @@ impl Drop for LockedPort { .unwrap(); } } + +pub fn get_node_binary_path() -> String { + std::env::var("ANVIL_ZKSYNC_BINARY_PATH") + .unwrap_or("../target/release/anvil-zksync".to_string()) +} diff --git a/e2e-tests-rust/tests/lib.rs b/e2e-tests-rust/tests/lib.rs index 05a4f300..42b5ccdc 100644 --- a/e2e-tests-rust/tests/lib.rs +++ b/e2e-tests-rust/tests/lib.rs @@ -2,12 +2,13 @@ use alloy::network::ReceiptResponse; use alloy::providers::ext::AnvilApi; use alloy::providers::Provider; use anvil_zksync_e2e_tests::{ - init_testing_provider, AnvilZKsyncApi, ReceiptExt, ZksyncWalletProviderExt, DEFAULT_TX_VALUE, + init_testing_provider, init_testing_provider_with_http_headers, AnvilZKsyncApi, ReceiptExt, ZksyncWalletProviderExt, DEFAULT_TX_VALUE, }; use alloy::{ primitives::U256, signers::local::PrivateKeySigner, }; +use alloy::transports::http::reqwest::header::{HeaderMap, HeaderValue, ORIGIN}; use std::convert::identity; use std::time::Duration; @@ -363,3 +364,37 @@ async fn set_chain_id() -> anyhow::Result<()> { Ok(()) } + +#[tokio::test] +async fn cli_no_cors() -> anyhow::Result<()> { + let mut headers = HeaderMap::new(); + headers.insert(ORIGIN, HeaderValue::from_static("http://some.origin")); + + // Verify all origins are allowed by default + let provider = init_testing_provider_with_http_headers(headers.clone(), identity).await?; + provider.get_chain_id().await?; + + // Verify no origins are allowed with --no-cors + let provider_with_no_cors = init_testing_provider_with_http_headers(headers.clone(), |node| node.arg("--no-cors=true")).await?; + let error_resp = provider_with_no_cors.get_chain_id().await.unwrap_err(); + assert_eq!(error_resp.to_string().contains("Origin of the request is not whitelisted"), true); + + Ok(()) +} + +#[tokio::test] +async fn cli_allow_origin() -> anyhow::Result<()> { + let mut headers = HeaderMap::new(); + headers.insert(ORIGIN, HeaderValue::from_static("http://some.origin")); + + // Verify allowed origin can make requests + let provider_with_allowed_origin = init_testing_provider_with_http_headers(headers.clone(), |node| node.arg("--allow-origin=http://some.origin")).await?; + provider_with_allowed_origin.get_chain_id().await?; + + // Verify different origin is not allowed + let provider_with_not_allowed_origin = init_testing_provider_with_http_headers(headers.clone(), |node| node.arg("--allow-origin=http://other.origin")).await?; + let error_resp = provider_with_not_allowed_origin.get_chain_id().await.unwrap_err(); + assert_eq!(error_resp.to_string().contains("Origin of the request is not whitelisted"), true); + + Ok(()) +}