From 3c5a913a8085df47283b2ccd506d61e5c61ad59c Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Thu, 11 Apr 2024 11:04:27 +0900 Subject: [PATCH] feat(hermes): add sse endpoint (#1425) * add initial sse code * fix typo * add more error handling * fix formatting * revert import format * add error handling for nonexistent price feeds in the middle of sub * refactor * format * add comment * Update hermes/src/api/sse.rs Co-authored-by: Reisen * refactor * bump --------- Co-authored-by: Reisen --- hermes/Cargo.lock | 8 +- hermes/Cargo.toml | 3 +- hermes/src/api.rs | 5 ++ hermes/src/api/rest.rs | 2 + hermes/src/api/sse.rs | 173 +++++++++++++++++++++++++++++++++++++++++ hermes/src/main.rs | 4 +- 6 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 hermes/src/api/sse.rs diff --git a/hermes/Cargo.lock b/hermes/Cargo.lock index 8d6a0c634..c3332c8b6 100644 --- a/hermes/Cargo.lock +++ b/hermes/Cargo.lock @@ -1796,7 +1796,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermes" -version = "0.5.3" +version = "0.5.4" dependencies = [ "anyhow", "async-trait", @@ -1839,6 +1839,7 @@ dependencies = [ "solana-sdk", "strum", "tokio", + "tokio-stream", "tonic", "tonic-build", "tower-http", @@ -5188,9 +5189,9 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.4.1" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" dependencies = [ "winapi-util", ] @@ -5385,6 +5386,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/hermes/Cargo.toml b/hermes/Cargo.toml index 18a3d9b27..a62616921 100644 --- a/hermes/Cargo.toml +++ b/hermes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hermes" -version = "0.5.3" +version = "0.5.4" description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle." edition = "2021" @@ -42,6 +42,7 @@ serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhol sha3 = { version = "0.10.4" } strum = { version = "0.24.1", features = ["derive"] } tokio = { version = "1.26.0", features = ["full"] } +tokio-stream = { version = "0.1.15", features = ["full"] } tonic = { version = "0.10.1", features = ["tls"] } tower-http = { version = "0.4.0", features = ["cors"] } tracing = { version = "0.1.37", features = ["log"] } diff --git a/hermes/src/api.rs b/hermes/src/api.rs index edff3f1dd..107c65fb2 100644 --- a/hermes/src/api.rs +++ b/hermes/src/api.rs @@ -23,6 +23,7 @@ use { mod doc_examples; mod metrics_middleware; mod rest; +mod sse; pub mod types; mod ws; @@ -143,6 +144,10 @@ pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> { .route("/api/latest_price_feeds", get(rest::latest_price_feeds)) .route("/api/latest_vaas", get(rest::latest_vaas)) .route("/api/price_feed_ids", get(rest::price_feed_ids)) + .route( + "/v2/updates/price/stream", + get(sse::price_stream_sse_handler), + ) .route("/v2/updates/price/latest", get(rest::latest_price_updates)) .route( "/v2/updates/price/:publish_time", diff --git a/hermes/src/api/rest.rs b/hermes/src/api/rest.rs index 38133ba99..7cee21a9f 100644 --- a/hermes/src/api/rest.rs +++ b/hermes/src/api/rest.rs @@ -21,6 +21,7 @@ mod price_feed_ids; mod ready; mod v2; + pub use { get_price_feed::*, get_vaa::*, @@ -38,6 +39,7 @@ pub use { }, }; +#[derive(Debug)] pub enum RestError { BenchmarkPriceNotUnique, UpdateDataNotFound, diff --git a/hermes/src/api/sse.rs b/hermes/src/api/sse.rs new file mode 100644 index 000000000..1b690836a --- /dev/null +++ b/hermes/src/api/sse.rs @@ -0,0 +1,173 @@ +use { + crate::{ + aggregate::{ + AggregationEvent, + RequestTime, + }, + api::{ + rest::{ + verify_price_ids_exist, + RestError, + }, + types::{ + BinaryPriceUpdate, + EncodingType, + ParsedPriceUpdate, + PriceIdInput, + PriceUpdate, + }, + ApiState, + }, + }, + anyhow::Result, + axum::{ + extract::State, + response::sse::{ + Event, + KeepAlive, + Sse, + }, + }, + futures::Stream, + pyth_sdk::PriceIdentifier, + serde::Deserialize, + serde_qs::axum::QsQuery, + std::convert::Infallible, + tokio::sync::broadcast, + tokio_stream::{ + wrappers::BroadcastStream, + StreamExt as _, + }, + utoipa::IntoParams, +}; + +#[derive(Debug, Deserialize, IntoParams)] +#[into_params(parameter_in = Query)] +pub struct StreamPriceUpdatesQueryParams { + /// Get the most recent price update for this set of price feed ids. + /// + /// This parameter can be provided multiple times to retrieve multiple price updates, + /// for example see the following query string: + /// + /// ``` + /// ?ids[]=a12...&ids[]=b4c... + /// ``` + #[param(rename = "ids[]")] + #[param(example = "e62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43")] + ids: Vec, + + /// If true, include the parsed price update in the `parsed` field of each returned feed. + #[serde(default)] + encoding: EncodingType, + + /// If true, include the parsed price update in the `parsed` field of each returned feed. + #[serde(default = "default_true")] + parsed: bool, +} + +fn default_true() -> bool { + true +} + +#[utoipa::path( + get, + path = "/v2/updates/price/stream", + responses( + (status = 200, description = "Price updates retrieved successfully", body = PriceUpdate), + (status = 404, description = "Price ids not found", body = String) + ), + params(StreamPriceUpdatesQueryParams) +)] +/// SSE route handler for streaming price updates. +pub async fn price_stream_sse_handler( + State(state): State, + QsQuery(params): QsQuery, +) -> Result>>, RestError> { + let price_ids: Vec = params.ids.into_iter().map(Into::into).collect(); + + verify_price_ids_exist(&state, &price_ids).await?; + + // Clone the update_tx receiver to listen for new price updates + let update_rx: broadcast::Receiver = state.update_tx.subscribe(); + + // Convert the broadcast receiver into a Stream + let stream = BroadcastStream::new(update_rx); + + let sse_stream = stream.then(move |message| { + let state_clone = state.clone(); // Clone again to use inside the async block + let price_ids_clone = price_ids.clone(); // Clone again for use inside the async block + async move { + match message { + Ok(event) => { + match handle_aggregation_event( + event, + state_clone, + price_ids_clone, + params.encoding, + params.parsed, + ) + .await + { + Ok(price_update) => Ok(Event::default().json_data(price_update).unwrap()), + Err(e) => Ok(error_event(e)), + } + } + Err(e) => Ok(error_event(e)), + } + } + }); + + Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default())) +} + +async fn handle_aggregation_event( + event: AggregationEvent, + state: ApiState, + mut price_ids: Vec, + encoding: EncodingType, + parsed: bool, +) -> Result { + // We check for available price feed ids to ensure that the price feed ids provided exists since price feeds can be removed. + let available_price_feed_ids = crate::aggregate::get_price_feed_ids(&*state.state).await; + + price_ids.retain(|price_feed_id| available_price_feed_ids.contains(price_feed_id)); + + let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data( + &*state.state, + &price_ids, + RequestTime::AtSlot(event.slot()), + ) + .await?; + let price_update_data = price_feeds_with_update_data.update_data; + let encoded_data: Vec = price_update_data + .into_iter() + .map(|data| encoding.encode_str(&data)) + .collect(); + let binary_price_update = BinaryPriceUpdate { + encoding, + data: encoded_data, + }; + let parsed_price_updates: Option> = if parsed { + Some( + price_feeds_with_update_data + .price_feeds + .into_iter() + .map(|price_feed| price_feed.into()) + .collect(), + ) + } else { + None + }; + + + Ok(PriceUpdate { + binary: binary_price_update, + parsed: parsed_price_updates, + }) +} + +fn error_event(e: E) -> Event { + Event::default() + .event("error") + .data(format!("Error receiving update: {:?}", e)) +} diff --git a/hermes/src/main.rs b/hermes/src/main.rs index abc6a05d7..0cd05d24d 100644 --- a/hermes/src/main.rs +++ b/hermes/src/main.rs @@ -28,14 +28,14 @@ mod state; lazy_static! { /// A static exit flag to indicate to running threads that we're shutting down. This is used to - /// gracefully shutdown the application. + /// gracefully shut down the application. /// /// We make this global based on the fact the: /// - The `Sender` side does not rely on any async runtime. /// - Exit logic doesn't really require carefully threading this value through the app. /// - The `Receiver` side of a watch channel performs the detection based on if the change /// happened after the subscribe, so it means all listeners should always be notified - /// currectly. + /// correctly. pub static ref EXIT: watch::Sender = watch::channel(false).0; }