diff --git a/hermes/Cargo.lock b/hermes/Cargo.lock index f1d77ff38a..06c74ebc54 100644 --- a/hermes/Cargo.lock +++ b/hermes/Cargo.lock @@ -1798,6 +1798,24 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "governor" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4" +dependencies = [ + "cfg-if", + "dashmap", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot 0.12.1", + "quanta", + "rand 0.8.5", + "smallvec", +] + [[package]] name = "h2" version = "0.3.20" @@ -1858,7 +1876,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermes" -version = "0.2.0" +version = "0.2.1" dependencies = [ "anyhow", "async-trait", @@ -1873,13 +1891,16 @@ dependencies = [ "derive_more", "env_logger 0.10.0", "futures", + "governor", "hex", "humantime", + "ipnet", "lazy_static", "libc", "libp2p", "log", "mock_instant", + "nonzero_ext", "prometheus-client", "pyth-sdk", "pythnet-sdk", @@ -3021,6 +3042,15 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "mach2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" +dependencies = [ + "libc", +] + [[package]] name = "match_cfg" version = "0.1.0" @@ -3317,6 +3347,12 @@ dependencies = [ "libc", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nohash-hasher" version = "0.2.0" @@ -3333,6 +3369,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -4052,6 +4094,22 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "quanta" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +dependencies = [ + "crossbeam-utils", + "libc", + "mach2", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -4220,6 +4278,15 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "rayon" version = "1.7.0" diff --git a/hermes/Cargo.toml b/hermes/Cargo.toml index e8904489f7..0f590a73dd 100644 --- a/hermes/Cargo.toml +++ b/hermes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hermes" -version = "0.2.0" +version = "0.2.1" description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle." edition = "2021" @@ -19,10 +19,13 @@ env_logger = { version = "0.10.0" } futures = { version = "0.3.28" } hex = { version = "0.4.3", features = ["serde"] } humantime = { version = "2.1.0" } +ipnet = { version = "2.8.0" } +governor = { version = "0.6.0" } lazy_static = { version = "1.4.0" } libc = { version = "0.2.140" } log = { version = "0.4.17" } mock_instant = { version = "0.3.1", features = ["sync"] } +nonzero_ext = { version = "0.3.0" } prometheus-client = { version = "0.21.1" } pyth-sdk = { version = "0.8.0" } pythnet-sdk = { path = "../pythnet/pythnet_sdk/", version = "2.0.0", features = ["strum"] } diff --git a/hermes/src/api.rs b/hermes/src/api.rs index df3b026a73..380546e545 100644 --- a/hermes/src/api.rs +++ b/hermes/src/api.rs @@ -11,10 +11,14 @@ use { routing::get, Router, }, + ipnet::IpNet, serde_qs::axum::QsQueryConfig, - std::sync::{ - atomic::Ordering, - Arc, + std::{ + net::SocketAddr, + sync::{ + atomic::Ordering, + Arc, + }, }, tokio::{ signal, @@ -36,10 +40,10 @@ pub struct ApiState { } impl ApiState { - pub fn new(state: Arc) -> Self { + pub fn new(state: Arc, ws_whitelist: Vec) -> Self { Self { state, - ws: Arc::new(ws::WsState::new()), + ws: Arc::new(ws::WsState::new(ws_whitelist)), } } } @@ -84,7 +88,7 @@ pub async fn run( )] struct ApiDoc; - let state = ApiState::new(state); + let state = ApiState::new(state, opts.rpc.ws_whitelist); // Initialize Axum Router. Note the type here is a `Router` due to the use of the // `with_state` method which replaces `Body` with `State` in the type signature. @@ -131,7 +135,7 @@ pub async fn run( // Binds the axum's server to the configured address and port. This is a blocking call and will // not return until the server is shutdown. axum::Server::try_bind(&opts.rpc.addr)? - .serve(app.into_make_service()) + .serve(app.into_make_service_with_connect_info::()) .with_graceful_shutdown(async { // Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler // should also have triggered so we will let that one print the shutdown warning. diff --git a/hermes/src/api/ws.rs b/hermes/src/api/ws.rs index e71dcb3354..e7e967288f 100644 --- a/hermes/src/api/ws.rs +++ b/hermes/src/api/ws.rs @@ -21,6 +21,7 @@ use { WebSocket, WebSocketUpgrade, }, + ConnectInfo, State as AxumState, }, response::IntoResponse, @@ -35,6 +36,13 @@ use { SinkExt, StreamExt, }, + governor::{ + DefaultKeyedRateLimiter, + Quota, + RateLimiter, + }, + ipnet::IpNet, + nonzero_ext::nonzero, pyth_sdk::PriceIdentifier, serde::{ Deserialize, @@ -42,6 +50,11 @@ use { }, std::{ collections::HashMap, + net::{ + IpAddr, + SocketAddr, + }, + num::NonZeroU32, sync::{ atomic::{ AtomicUsize, @@ -54,8 +67,13 @@ use { tokio::sync::mpsc, }; -pub const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30); -pub const NOTIFICATIONS_CHAN_LEN: usize = 1000; +const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30); +const NOTIFICATIONS_CHAN_LEN: usize = 1000; +const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB + +/// The maximum number of bytes that can be sent per second per IP address. +/// If the limit is exceeded, the connection is closed. +const BYTES_LIMIT_PER_IP_PER_SECOND: u32 = 256 * 1024; // 256 KiB #[derive(Clone)] pub struct PriceFeedClientConfig { @@ -65,15 +83,21 @@ pub struct PriceFeedClientConfig { } pub struct WsState { - pub subscriber_counter: AtomicUsize, - pub subscribers: DashMap>, + pub subscriber_counter: AtomicUsize, + pub subscribers: DashMap>, + pub bytes_limit_whitelist: Vec, + pub rate_limiter: DefaultKeyedRateLimiter, } impl WsState { - pub fn new() -> Self { + pub fn new(whitelist: Vec) -> Self { Self { - subscriber_counter: AtomicUsize::new(0), - subscribers: DashMap::new(), + subscriber_counter: AtomicUsize::new(0), + subscribers: DashMap::new(), + rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!( + BYTES_LIMIT_PER_IP_PER_SECOND + ))), + bytes_limit_whitelist: whitelist, } } } @@ -118,20 +142,29 @@ enum ServerResponseMessage { pub async fn ws_route_handler( ws: WebSocketUpgrade, AxumState(state): AxumState, + ConnectInfo(addr): ConnectInfo, ) -> impl IntoResponse { - ws.on_upgrade(|socket| websocket_handler(socket, state)) + ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE) + .on_upgrade(move |socket| websocket_handler(socket, state, addr)) } -#[tracing::instrument(skip(stream, state))] -async fn websocket_handler(stream: WebSocket, state: super::ApiState) { +#[tracing::instrument(skip(stream, state, addr))] +async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) { let ws_state = state.ws.clone(); let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst); - tracing::debug!(id, "New Websocket Connection"); + tracing::debug!(id, %addr, "New Websocket Connection"); let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN); let (sender, receiver) = stream.split(); - let mut subscriber = - Subscriber::new(id, state.state.clone(), notify_receiver, receiver, sender); + let mut subscriber = Subscriber::new( + id, + addr.ip(), + state.state.clone(), + state.ws.clone(), + notify_receiver, + receiver, + sender, + ); ws_state.subscribers.insert(id, notify_sender); subscriber.run().await; @@ -143,8 +176,10 @@ pub type SubscriberId = usize; /// It listens to the store for updates and sends them to the client. pub struct Subscriber { id: SubscriberId, + ip_addr: IpAddr, closed: bool, store: Arc, + ws_state: Arc, notify_receiver: mpsc::Receiver, receiver: SplitStream, sender: SplitSink, @@ -156,15 +191,19 @@ pub struct Subscriber { impl Subscriber { pub fn new( id: SubscriberId, + ip_addr: IpAddr, store: Arc, + ws_state: Arc, notify_receiver: mpsc::Receiver, receiver: SplitStream, sender: SplitSink, ) -> Self { Self { id, + ip_addr, closed: false, store, + ws_state, notify_receiver, receiver, sender, @@ -243,19 +282,45 @@ impl Subscriber { } } + let message = serde_json::to_string(&ServerMessage::PriceUpdate { + price_feed: RpcPriceFeed::from_price_feed_update( + update, + config.verbose, + config.binary, + ), + })?; + + // Close the connection if rate limit is exceeded and the ip is not whitelisted. + if !self + .ws_state + .bytes_limit_whitelist + .contains(&self.ip_addr.into()) + && self.ws_state.rate_limiter.check_key_n( + &self.ip_addr, + NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?, + ) != Ok(Ok(())) + { + tracing::info!( + self.id, + ip = %self.ip_addr, + "Rate limit exceeded. Closing connection.", + ); + self.sender + .send( + serde_json::to_string(&ServerResponseMessage::Err { + error: "Rate limit exceeded".to_string(), + })? + .into(), + ) + .await?; + self.sender.close().await?; + self.closed = true; + return Ok(()); + } + // `sender.feed` buffers a message to the client but does not flush it, so we can send // multiple messages and flush them all at once. - self.sender - .feed(Message::Text(serde_json::to_string( - &ServerMessage::PriceUpdate { - price_feed: RpcPriceFeed::from_price_feed_update( - update, - config.verbose, - config.binary, - ), - }, - )?)) - .await?; + self.sender.feed(message.into()).await?; } self.sender.flush().await?; @@ -394,4 +459,7 @@ pub async fn notify_updates(ws_state: Arc, event: AggregationEvent) { ws_state.subscribers.remove(&id); } }); + + // Clean the bytes limiting dictionary + ws_state.rate_limiter.retain_recent(); } diff --git a/hermes/src/config/rpc.rs b/hermes/src/config/rpc.rs index 6700d399b0..7d03a3bb15 100644 --- a/hermes/src/config/rpc.rs +++ b/hermes/src/config/rpc.rs @@ -1,5 +1,6 @@ use { clap::Args, + ipnet::IpNet, std::net::SocketAddr, }; @@ -14,4 +15,10 @@ pub struct Options { #[arg(default_value = DEFAULT_RPC_ADDR)] #[arg(env = "RPC_ADDR")] pub addr: SocketAddr, + + /// Whitelisted websocket ip network addresses (separated by comma). + #[arg(long = "rpc-ws-whitelist")] + #[arg(value_delimiter = ',')] + #[arg(env = "RPC_WS_WHITELIST")] + pub ws_whitelist: Vec, }