Skip to content

Commit

Permalink
feat(hermes): add sse endpoint (#1425)
Browse files Browse the repository at this point in the history
* add initial sse code

* fix typo

* add more error handling

* fix formatting

* revert import format

* add error handling for nonexistent price feeds in the middle of sub

* refactor

* format

* add comment

* Update hermes/src/api/sse.rs

Co-authored-by: Reisen <[email protected]>

* refactor

* bump

---------

Co-authored-by: Reisen <[email protected]>
  • Loading branch information
cctdaniel and Reisen authored Apr 11, 2024
1 parent e1f9783 commit 3c5a913
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 6 deletions.
8 changes: 5 additions & 3 deletions hermes/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion hermes/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "hermes"
version = "0.5.3"
version = "0.5.4"
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
edition = "2021"

Expand Down Expand Up @@ -42,6 +42,7 @@ serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhol
sha3 = { version = "0.10.4" }
strum = { version = "0.24.1", features = ["derive"] }
tokio = { version = "1.26.0", features = ["full"] }
tokio-stream = { version = "0.1.15", features = ["full"] }
tonic = { version = "0.10.1", features = ["tls"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = { version = "0.1.37", features = ["log"] }
Expand Down
5 changes: 5 additions & 0 deletions hermes/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use {
mod doc_examples;
mod metrics_middleware;
mod rest;
mod sse;
pub mod types;
mod ws;

Expand Down Expand Up @@ -143,6 +144,10 @@ pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
.route("/api/latest_price_feeds", get(rest::latest_price_feeds))
.route("/api/latest_vaas", get(rest::latest_vaas))
.route("/api/price_feed_ids", get(rest::price_feed_ids))
.route(
"/v2/updates/price/stream",
get(sse::price_stream_sse_handler),
)
.route("/v2/updates/price/latest", get(rest::latest_price_updates))
.route(
"/v2/updates/price/:publish_time",
Expand Down
2 changes: 2 additions & 0 deletions hermes/src/api/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod price_feed_ids;
mod ready;
mod v2;


pub use {
get_price_feed::*,
get_vaa::*,
Expand All @@ -38,6 +39,7 @@ pub use {
},
};

#[derive(Debug)]
pub enum RestError {
BenchmarkPriceNotUnique,
UpdateDataNotFound,
Expand Down
173 changes: 173 additions & 0 deletions hermes/src/api/sse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use {
crate::{
aggregate::{
AggregationEvent,
RequestTime,
},
api::{
rest::{
verify_price_ids_exist,
RestError,
},
types::{
BinaryPriceUpdate,
EncodingType,
ParsedPriceUpdate,
PriceIdInput,
PriceUpdate,
},
ApiState,
},
},
anyhow::Result,
axum::{
extract::State,
response::sse::{
Event,
KeepAlive,
Sse,
},
},
futures::Stream,
pyth_sdk::PriceIdentifier,
serde::Deserialize,
serde_qs::axum::QsQuery,
std::convert::Infallible,
tokio::sync::broadcast,
tokio_stream::{
wrappers::BroadcastStream,
StreamExt as _,
},
utoipa::IntoParams,
};

#[derive(Debug, Deserialize, IntoParams)]
#[into_params(parameter_in = Query)]
pub struct StreamPriceUpdatesQueryParams {
/// Get the most recent price update for this set of price feed ids.
///
/// This parameter can be provided multiple times to retrieve multiple price updates,
/// for example see the following query string:
///
/// ```
/// ?ids[]=a12...&ids[]=b4c...
/// ```
#[param(rename = "ids[]")]
#[param(example = "e62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43")]
ids: Vec<PriceIdInput>,

/// If true, include the parsed price update in the `parsed` field of each returned feed.
#[serde(default)]
encoding: EncodingType,

/// If true, include the parsed price update in the `parsed` field of each returned feed.
#[serde(default = "default_true")]
parsed: bool,
}

fn default_true() -> bool {
true
}

#[utoipa::path(
get,
path = "/v2/updates/price/stream",
responses(
(status = 200, description = "Price updates retrieved successfully", body = PriceUpdate),
(status = 404, description = "Price ids not found", body = String)
),
params(StreamPriceUpdatesQueryParams)
)]
/// SSE route handler for streaming price updates.
pub async fn price_stream_sse_handler(
State(state): State<ApiState>,
QsQuery(params): QsQuery<StreamPriceUpdatesQueryParams>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError> {
let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();

verify_price_ids_exist(&state, &price_ids).await?;

// Clone the update_tx receiver to listen for new price updates
let update_rx: broadcast::Receiver<AggregationEvent> = state.update_tx.subscribe();

// Convert the broadcast receiver into a Stream
let stream = BroadcastStream::new(update_rx);

let sse_stream = stream.then(move |message| {
let state_clone = state.clone(); // Clone again to use inside the async block
let price_ids_clone = price_ids.clone(); // Clone again for use inside the async block
async move {
match message {
Ok(event) => {
match handle_aggregation_event(
event,
state_clone,
price_ids_clone,
params.encoding,
params.parsed,
)
.await
{
Ok(price_update) => Ok(Event::default().json_data(price_update).unwrap()),
Err(e) => Ok(error_event(e)),
}
}
Err(e) => Ok(error_event(e)),
}
}
});

Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}

async fn handle_aggregation_event(
event: AggregationEvent,
state: ApiState,
mut price_ids: Vec<PriceIdentifier>,
encoding: EncodingType,
parsed: bool,
) -> Result<PriceUpdate> {
// We check for available price feed ids to ensure that the price feed ids provided exists since price feeds can be removed.
let available_price_feed_ids = crate::aggregate::get_price_feed_ids(&*state.state).await;

price_ids.retain(|price_feed_id| available_price_feed_ids.contains(price_feed_id));

let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
&*state.state,
&price_ids,
RequestTime::AtSlot(event.slot()),
)
.await?;
let price_update_data = price_feeds_with_update_data.update_data;
let encoded_data: Vec<String> = price_update_data
.into_iter()
.map(|data| encoding.encode_str(&data))
.collect();
let binary_price_update = BinaryPriceUpdate {
encoding,
data: encoded_data,
};
let parsed_price_updates: Option<Vec<ParsedPriceUpdate>> = if parsed {
Some(
price_feeds_with_update_data
.price_feeds
.into_iter()
.map(|price_feed| price_feed.into())
.collect(),
)
} else {
None
};


Ok(PriceUpdate {
binary: binary_price_update,
parsed: parsed_price_updates,
})
}

fn error_event<E: std::fmt::Debug>(e: E) -> Event {
Event::default()
.event("error")
.data(format!("Error receiving update: {:?}", e))
}
4 changes: 2 additions & 2 deletions hermes/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ mod state;

lazy_static! {
/// A static exit flag to indicate to running threads that we're shutting down. This is used to
/// gracefully shutdown the application.
/// gracefully shut down the application.
///
/// We make this global based on the fact the:
/// - The `Sender` side does not rely on any async runtime.
/// - Exit logic doesn't really require carefully threading this value through the app.
/// - The `Receiver` side of a watch channel performs the detection based on if the change
/// happened after the subscribe, so it means all listeners should always be notified
/// currectly.
/// correctly.
pub static ref EXIT: watch::Sender<bool> = watch::channel(false).0;
}

Expand Down

0 comments on commit 3c5a913

Please sign in to comment.