diff --git a/broker/src/crypto.rs b/broker/src/crypto.rs
index 598f53e9..1f12c371 100644
--- a/broker/src/crypto.rs
+++ b/broker/src/crypto.rs
@@ -9,7 +9,7 @@ use std::time::Duration;
 use tokio::{sync::RwLock, time::timeout};
 use tracing::{debug, error, warn, info};
 
-use crate::health::{self, Health, VaultStatus};
+use crate::serve_health::{Health, VaultStatus};
 
 pub struct GetCertsFromPki {
     pki_realm: String,
diff --git a/broker/src/health.rs b/broker/src/health.rs
deleted file mode 100644
index c28e46c3..00000000
--- a/broker/src/health.rs
+++ /dev/null
@@ -1,85 +0,0 @@
-use std::{fmt::Display, sync::Arc, time::{Duration, SystemTime}, collections::HashMap};
-
-use serde::{Serialize, Deserialize};
-use beam_lib::ProxyId;
-use tokio::sync::RwLock;
-use tracing::{info, warn};
-
-#[derive(Serialize)]
-#[serde(rename_all = "lowercase")]
-pub enum Verdict {
-    Healthy,
-    Unhealthy,
-    Unknown,
-}
-
-impl Default for Verdict {
-    fn default() -> Self {
-        Verdict::Unknown
-    }
-}
-
-#[derive(Debug, Serialize, Clone, Copy, Default)]
-#[serde(rename_all = "lowercase")]
-pub enum VaultStatus {
-    Ok,
-    #[default]
-    Unknown,
-    OtherError,
-    LockedOrSealed,
-    Unreachable,
-}
-
-#[derive(Debug, Serialize, Clone, Copy, Default)]
-#[serde(rename_all = "lowercase")]
-pub enum InitStatus {
-    #[default]
-    Unknown,
-    FetchingIntermediateCert,
-    Done
-}
-
-#[derive(Debug, Default)]
-pub struct Health {
-    pub vault: VaultStatus,
-    pub initstatus: InitStatus,
-    pub proxies: HashMap<ProxyId, ProxyStatus>
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ProxyStatus {
-    last_connect: SystemTime,
-    last_disconnect: Option<SystemTime>,
-    #[serde(skip)]
-    connections: u8,
-}
-
-impl ProxyStatus {
-    pub fn online(&self) -> bool {
-        self.connections > 0
-    }
-
-    pub fn disconnect(&mut self) {
-        self.last_disconnect = Some(SystemTime::now());
-        self.connections -= 1;
-    }
-
-    pub fn connect(&mut self) {
-        self.connections += 1;
-        self.last_connect = SystemTime::now();
-    }
-
-    pub fn _last_seen(&self) -> SystemTime {
-        if self.online() {
-            SystemTime::now()
-        } else {
-            self.last_disconnect.expect("Should always exist as the proxy is not online")
-        }
-    }
-}
-
-impl ProxyStatus {
-    pub fn new() -> ProxyStatus {
-        ProxyStatus { last_connect: SystemTime::now(), connections: 1, last_disconnect: None }
-    }
-}
diff --git a/broker/src/main.rs b/broker/src/main.rs
index 4fe6d51f..723df949 100644
--- a/broker/src/main.rs
+++ b/broker/src/main.rs
@@ -2,7 +2,6 @@
 
 mod banner;
 mod crypto;
-mod health;
 mod serve;
 mod serve_health;
 mod serve_pki;
@@ -15,7 +14,7 @@ mod compare_client_server_version;
 use std::{collections::HashMap, sync::Arc, time::Duration};
 
 use crypto::GetCertsFromPki;
-use health::{Health, InitStatus};
+use serve_health::{Health, InitStatus};
 use once_cell::sync::Lazy;
 use shared::{config::CONFIG_CENTRAL, *, errors::SamplyBeamError};
 use tokio::sync::RwLock;
@@ -45,8 +44,8 @@ pub async fn main() -> anyhow::Result<()> {
 
 async fn init_broker_ca_chain(health: Arc<RwLock<Health>>) {
     {
-        health.write().await.initstatus = health::InitStatus::FetchingIntermediateCert
+        health.write().await.initstatus = InitStatus::FetchingIntermediateCert
     }
     shared::crypto::init_ca_chain().await.expect("Failed to init broker ca chain");
-    health.write().await.initstatus = health::InitStatus::Done;
+    health.write().await.initstatus = InitStatus::Done;
 }
diff --git a/broker/src/serve.rs b/broker/src/serve.rs
index 7b728b2e..6ea6c054 100644
--- a/broker/src/serve.rs
+++ b/broker/src/serve.rs
@@ -20,7 +20,7 @@ use tokio::{
 };
 use tracing::{debug, info, trace, warn};
 
-use crate::{banner, crypto, health::Health, serve_health, serve_pki, serve_tasks, compare_client_server_version};
+use crate::{banner, crypto, serve_health::Health, serve_health, serve_pki, serve_tasks, compare_client_server_version};
 
 pub(crate) async fn serve(health: Arc<RwLock<Health>>) -> anyhow::Result<()> {
     let app = serve_tasks::router()
diff --git a/broker/src/serve_health.rs b/broker/src/serve_health.rs
index 9a4d93fb..802fdd93 100644
--- a/broker/src/serve_health.rs
+++ b/broker/src/serve_health.rs
@@ -1,13 +1,14 @@
-use std::{sync::Arc, time::{Duration, SystemTime}};
+use std::{collections::HashMap, convert::Infallible, marker::PhantomData, sync::Arc, time::{Duration, SystemTime}};
 
-use axum::{extract::{State, Path}, http::StatusCode, routing::get, Json, Router, response::Response};
+use axum::{extract::{Path, State}, http::StatusCode, response::{sse::{Event, KeepAlive}, Response, Sse}, routing::get, Json, Router};
 use axum_extra::{headers::{authorization::Basic, Authorization}, TypedHeader};
 use beam_lib::ProxyId;
+use futures_core::Stream;
 use serde::{Serialize, Deserialize};
 use shared::{crypto_jwt::Authorized, Msg, config::CONFIG_CENTRAL};
-use tokio::sync::RwLock;
+use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
 
-use crate::{health::{Health, VaultStatus, Verdict, ProxyStatus, InitStatus}, compare_client_server_version::log_version_mismatch};
+use crate::compare_client_server_version::log_version_mismatch;
 
 #[derive(Serialize)]
 struct HealthOutput {
@@ -16,6 +17,58 @@ struct HealthOutput {
     init_status: InitStatus
 }
 
+#[derive(Serialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Verdict {
+    Healthy,
+    Unhealthy,
+    Unknown,
+}
+
+impl Default for Verdict {
+    fn default() -> Self {
+        Verdict::Unknown
+    }
+}
+
+#[derive(Debug, Serialize, Clone, Copy, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum VaultStatus {
+    Ok,
+    #[default]
+    Unknown,
+    OtherError,
+    LockedOrSealed,
+    Unreachable,
+}
+
+#[derive(Debug, Serialize, Clone, Copy, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum InitStatus {
+    #[default]
+    Unknown,
+    FetchingIntermediateCert,
+    Done
+}
+
+#[derive(Debug, Default)]
+pub struct Health {
+    pub vault: VaultStatus,
+    pub initstatus: InitStatus,
+    proxies: HashMap<ProxyId, ProxyStatus>
+}
+
+#[derive(Debug, Clone, Default)]
+struct ProxyStatus {
+    online_guard: Arc<Mutex<Option<SystemTime>>>
+}
+
+impl ProxyStatus {
+    pub fn is_online(&self) -> bool {
+        self.online_guard.try_lock().is_err()
+    }
+}
+
 pub(crate) fn router(health: Arc<RwLock<Health>>) -> Router {
     Router::new()
         .route("/v1/health", get(handler))
@@ -46,14 +99,14 @@ async fn handler(
 }
 
 async fn get_all_proxies(State(state): State<Arc<RwLock<Health>>>) -> Json<Vec<ProxyId>> {
-    Json(state.read().await.proxies.keys().cloned().collect())
+    Json(state.read().await.proxies.iter().filter(|(_, v)| v.is_online()).map(|(k, _)| k).cloned().collect())
 }
 
 async fn proxy_health(
     State(state): State<Arc<RwLock<Health>>>,
     Path(proxy): Path<ProxyId>,
     auth: TypedHeader<Authorization<Basic>>
-) -> Result<(StatusCode, Json<ProxyStatus>), StatusCode> {
+) -> Result<(StatusCode, Json<serde_json::Value>), StatusCode> {
     let Some(ref monitoring_key) = CONFIG_CENTRAL.monitoring_api_key else {
         return Err(StatusCode::NOT_IMPLEMENTED);
     };
@@ -63,10 +116,12 @@ async fn proxy_health(
     }
 
     if let Some(reported_back) = state.read().await.proxies.get(&proxy) {
-        if reported_back.online() {
-            Err(StatusCode::OK)
+        if let Ok(last_disconnect) = reported_back.online_guard.try_lock().as_deref().copied() {
+            Ok((StatusCode::SERVICE_UNAVAILABLE, Json(serde_json::json!({
+                "last_disconnect": last_disconnect
+            }))))
         } else {
-            Ok((StatusCode::SERVICE_UNAVAILABLE, Json(reported_back.clone())))
+            Err(StatusCode::OK)
         }
     } else {
         Err(StatusCode::NOT_FOUND)
@@ -76,48 +131,38 @@ async fn proxy_health(
 async fn get_control_tasks(
     State(state): State<Arc<RwLock<Health>>>,
     proxy_auth: Authorized,
-) -> StatusCode {
+) -> Result<Sse<ForeverStream>, StatusCode> {
     let proxy_id = proxy_auth.get_from().proxy_id();
     // Once this is freed the connection will be removed from the map of connected proxies again
     // This ensures that when the connection is dropped and therefore this response future the status of this proxy will be updated
-    let _connection_remover = ConnectedGuard::connect(&proxy_id, &state).await;
-
-    // In the future, this will wait for control tasks for the given proxy
-    tokio::time::sleep(Duration::from_secs(60 * 60)).await;
+    let status_mutex = state
+        .write()
+        .await
+        .proxies
+        .entry(proxy_id)
+        .or_default()
+        .online_guard
+        .clone();
+    let Ok(connect_guard) = tokio::time::timeout(Duration::from_secs(60), status_mutex.lock_owned()).await
+    else {
+        return Err(StatusCode::CONFLICT);
+    };
 
-    StatusCode::OK
+    Ok(Sse::new(ForeverStream(connect_guard)).keep_alive(KeepAlive::new()))
 }
 
-struct ConnectedGuard<'a> {
-    proxy: &'a ProxyId,
-    state: &'a Arc<RwLock<Health>>
-}
+struct ForeverStream(OwnedMutexGuard<Option<SystemTime>>);
 
-impl<'a> ConnectedGuard<'a> {
-    async fn connect(proxy: &'a ProxyId, state: &'a Arc<RwLock<Health>>) -> ConnectedGuard<'a> {
-        {
-            state.write().await.proxies
-                .entry(proxy.clone())
-                .and_modify(ProxyStatus::connect)
-                .or_insert(ProxyStatus::new());
-        }
-        Self { proxy, state }
+impl Stream for ForeverStream {
+    type Item = Result<Event, Infallible>;
+
+    fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
+        std::task::Poll::Pending
     }
 }
 
-impl<'a> Drop for ConnectedGuard<'a> {
+impl Drop for ForeverStream {
     fn drop(&mut self) {
-        let proxy_id = self.proxy.clone();
-        let map = self.state.clone();
-        tokio::spawn(async move {
-            // We wait here for one second to give the client a bit of time to reconnect incrementing the connection count so that it will be one again after the decrement
-            tokio::time::sleep(Duration::from_secs(1)).await;
-            map.write()
-                .await
-                .proxies
-                .get_mut(&proxy_id)
-                .expect("Has to exist as we don't remove items and the constructor of this type inserts the entry")
-                .disconnect();
-        });
+        *self.0 = Some(SystemTime::now());
     }
-}
+}
\ No newline at end of file
diff --git a/proxy/src/main.rs b/proxy/src/main.rs
index 59bf311a..c822db99 100644
--- a/proxy/src/main.rs
+++ b/proxy/src/main.rs
@@ -6,6 +6,7 @@ use std::time::Duration;
 use axum::http::{header, HeaderValue, StatusCode};
 use beam_lib::AppOrProxyId;
 use futures::future::Ready;
+use futures::{StreamExt, TryStreamExt};
 use shared::{reqwest, EncryptedMessage, MsgEmpty, PlainMessage};
 use shared::crypto::CryptoPublicPortion;
 use shared::errors::SamplyBeamError;
@@ -132,8 +133,12 @@ fn spawn_controller_polling(client: SamplyHttpClient, config: Config) {
     const RETRY_INTERVAL: Duration = Duration::from_secs(60);
     tokio::spawn(async move {
        let mut retries_this_min = 0;
-        let mut reset_interval = std::pin::pin!(tokio::time::sleep(Duration::from_secs(60)));
+        let mut reset_interval = Instant::now() + RETRY_INTERVAL;
        loop {
+            if reset_interval < Instant::now() {
+                retries_this_min = 0;
+                reset_interval = Instant::now() + RETRY_INTERVAL;
+            }
            let body = EncryptedMessage::MsgEmpty(MsgEmpty {
                from: AppOrProxyId::Proxy(config.proxy_id.clone()),
            });
@@ -145,39 +150,48 @@ fn spawn_controller_polling(client: SamplyHttpClient, config: Config) {
            let req = sign_request(body, parts, &config, None).await.expect("Unable to sign request; this should always work");
 
            // In the future this will poll actual control related tasks
-            match client.execute(req).await {
-                Ok(res) => {
-                    match res.status() {
-                        StatusCode::OK => {
-                            // Process control task
-                        },
-                        status @ (StatusCode::GATEWAY_TIMEOUT | StatusCode::BAD_GATEWAY) => {
-                            if retries_this_min < 10 {
-                                retries_this_min += 1;
-                                debug!("Connection to broker timed out; retrying.");
-                            } else {
-                                warn!("Retried more then 10 times in one minute getting status code: {status}");
-                                tokio::time::sleep(RETRY_INTERVAL).await;
-                                continue;
-                            }
-                        },
-                        other => {
-                            warn!("Got unexpected status getting control tasks from broker: {other}");
-                            tokio::time::sleep(RETRY_INTERVAL).await;
-                        }
-                    };
-                },
-                Err(e) if e.is_timeout() => {
-                    debug!("Connection to broker timed out; retrying: {e}");
+            let res = match client.execute(req).await {
+                Ok(res) if res.status() == StatusCode::CONFLICT => {
+                    error!("A beam proxy with the same id is already running!");
+                    std::process::exit(409);
                },
+                Ok(res) if res.status() != StatusCode::OK => {
+                    if retries_this_min < 10 {
+                        retries_this_min += 1;
+                        debug!("Unexpected status code getting control tasks from broker: {}", res.status());
+                    } else {
+                        warn!("Retried more then 10 times in one minute getting status code: {}", res.status());
+                        tokio::time::sleep(RETRY_INTERVAL).await;
+                    }
+                    continue;
+                }
+                Ok(res) => res,
                Err(e) => {
                    warn!("Error getting control tasks from broker; retrying in {}s: {e}", RETRY_INTERVAL.as_secs());
                    tokio::time::sleep(RETRY_INTERVAL).await;
+                    continue;
                }
            };
-            if reset_interval.is_elapsed() {
-                retries_this_min = 0;
-                reset_interval.as_mut().reset(Instant::now() + Duration::from_secs(60));
+            let incoming = res
+                .bytes_stream()
+                .map(|result| result.map_err(|error| {
+                    let kind = error.is_timeout().then_some(std::io::ErrorKind::TimedOut).unwrap_or(std::io::ErrorKind::Other);
+                    std::io::Error::new(kind, format!("IO Error: {error}"))
+                }))
+                .into_async_read();
+            let mut reader = async_sse::decode(incoming);
+            while let Some(ev) = reader.next().await {
+                match ev {
+                    Ok(_)=> (),
+                    Err(e) if e.downcast_ref::<std::io::Error>().unwrap().kind() == std::io::ErrorKind::TimedOut => {
+                        debug!("SSE connection timed out");
+                        break;
+                    },
+                    Err(err) => {
+                        error!("Got error reading SSE stream: {err}");
+                        break;
+                    }
+                };
            }
        }
    });
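The core idea introduced in broker/src/serve_health.rs above is that a proxy counts as online exactly as long as its SSE response future holds the `online_guard` mutex: a failing `try_lock()` means a connection is alive, and dropping the guard stamps the disconnect time. The standalone sketch below illustrates that pattern in isolation; it assumes tokio with the `sync`, `macros`, and `rt-multi-thread` features, and the names `ProxyStatus` and `ConnectionGuard` are illustrative rather than the broker's exact types.

```rust
use std::{sync::Arc, time::SystemTime};
use tokio::sync::{Mutex, OwnedMutexGuard};

#[derive(Clone, Default)]
struct ProxyStatus {
    // Holds Some(last_disconnect) once the proxy has disconnected at least once.
    online_guard: Arc<Mutex<Option<SystemTime>>>,
}

impl ProxyStatus {
    // While a connection task holds the guard, try_lock() fails, so the proxy is online.
    fn is_online(&self) -> bool {
        self.online_guard.try_lock().is_err()
    }
}

// Wrapper that records the disconnect time when the connection future is dropped.
struct ConnectionGuard(OwnedMutexGuard<Option<SystemTime>>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        *self.0 = Some(SystemTime::now());
    }
}

#[tokio::main]
async fn main() {
    let status = ProxyStatus::default();

    // "Connect": take the owned guard, as the SSE handler does with lock_owned().
    let guard = ConnectionGuard(status.online_guard.clone().lock_owned().await);
    assert!(status.is_online());

    // "Disconnect": dropping the response future releases the guard and stamps last_disconnect.
    drop(guard);
    assert!(!status.is_online());
    assert!(status.online_guard.lock().await.is_some());
}
```

Because the guard is acquired with `lock_owned()` on a cloned `Arc`, it can be moved into the long-lived response future; a second connection attempt for the same proxy blocks on the same mutex, which is why `get_control_tasks` answers with `409 Conflict` when the lock cannot be obtained within 60 seconds.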