From 91af3e9ce9c7f69f32f6312b9849c02363e29387 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Sun, 4 Feb 2024 16:13:13 +0700 Subject: [PATCH] change gateway code as RFC-0003 --- Cargo.lock | 1 + README.md | 8 +- packages/cluster/src/define/mod.rs | 3 +- packages/cluster/src/define/rpc/gateway.rs | 3 +- packages/cluster/src/implement/mod.rs | 8 +- packages/cluster/src/implement/server.rs | 5 +- servers/media-server/Cargo.toml | 1 + servers/media-server/scripts/gateway.sh | 9 + .../media-server/scripts/gateway_global.sh | 7 - servers/media-server/scripts/gateway_inner.sh | 11 -- servers/media-server/scripts/gateway_other.sh | 10 + servers/media-server/scripts/media_rtmp.sh | 2 +- servers/media-server/scripts/media_sip.sh | 2 +- servers/media-server/scripts/media_webrtc.sh | 2 +- servers/media-server/src/main.rs | 56 ++---- servers/media-server/src/server/gateway.rs | 115 ++++++----- .../media-server/src/server/gateway/logic.rs | 135 +++++++++---- .../server/gateway/logic/global_registry.rs | 180 ++++++++++-------- .../server/gateway/logic/inner_registry.rs | 53 ++++-- .../media-server/src/server/gateway/rpc.rs | 2 +- .../src/server/gateway/rpc/http.rs | 2 +- .../src/server/gateway/webrtc_route.rs | 13 +- servers/media-server/src/server/rtmp.rs | 11 +- servers/media-server/src/server/webrtc.rs | 11 +- 24 files changed, 374 insertions(+), 276 deletions(-) create mode 100755 servers/media-server/scripts/gateway.sh delete mode 100755 servers/media-server/scripts/gateway_global.sh delete mode 100755 servers/media-server/scripts/gateway_inner.sh create mode 100755 servers/media-server/scripts/gateway_other.sh diff --git a/Cargo.lock b/Cargo.lock index 20ae6cd6..f11a666b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -429,6 +429,7 @@ dependencies = [ "atm0s-media-server-transport-sip", "atm0s-media-server-transport-webrtc", "atm0s-media-server-utils", + "bincode", "clap", "futures", "log", diff --git a/README.md b/README.md index 6983eb08..52db5c40 100644 --- a/README.md +++ b/README.md @@ -127,19 +127,19 @@ After that, we can access `http://localhost:3000/samples` to see all embedded sa In cluster mode, each module needs to be on a separate node. This setup can run on a single machine or multiple machines, whether they are connected via a public or private network. -The Inner-Gateway module routes user traffic to the most suitable media server node. +The Gateway node routes user traffic to the most suitable media server node. ```bash atm0s-media-server --node-id 10 --sdn-port 10010 --http-port 3000 gateway ``` -Afterward, the gateway prints out its address in the format: 10@/ip4/127.0.0.1/udp/10001/ip4/127.0.0.1/tcp/10001. This address serves as the seed node for other nodes joining the cluster. +Afterward, the gateway prints out its address in the format: 10@/ip4/127.0.0.1/udp/10010/ip4/127.0.0.1/tcp/10010. This address serves as the seed node for other nodes joining the cluster. -The WebRTC module serves users with either an SDK or a Whip, Whep client. +The WebRTC node serves users with either an SDK or a Whip, Whep client. ```bash atm0s-media-server --node-id 21 --http-port 3001 --seeds ABOVE_GATEWAY_ADDR webrtc ``` -The RTMP module serves users with an RTMP broadcaster such as OBS or Streamyard. +The RTMP node serves users with an RTMP broadcaster such as OBS or Streamyard. ```bash atm0s-media-server --node-id 30 --seeds ABOVE_GATEWAY_ADDR rtmp ``` diff --git a/packages/cluster/src/define/mod.rs b/packages/cluster/src/define/mod.rs index e2be02a4..2b41c6e4 100644 --- a/packages/cluster/src/define/mod.rs +++ b/packages/cluster/src/define/mod.rs @@ -61,7 +61,6 @@ where fn build(&mut self, room_id: &str, peer_id: &str) -> C; } -pub const GLOBAL_GATEWAY_SERVICE: u8 = 100; -pub const INNER_GATEWAY_SERVICE: u8 = 101; +pub const GATEWAY_SERVICE: u8 = 101; pub const MEDIA_SERVER_SERVICE: u8 = 102; pub const CONNECTOR_SERVICE: u8 = 103; diff --git a/packages/cluster/src/define/rpc/gateway.rs b/packages/cluster/src/define/rpc/gateway.rs index 5fa8da5e..8038528c 100644 --- a/packages/cluster/src/define/rpc/gateway.rs +++ b/packages/cluster/src/define/rpc/gateway.rs @@ -20,7 +20,7 @@ pub struct ServiceInfo { #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, IntoVecU8, TryFromSliceU8)] pub struct NodePing { pub node_id: u32, - pub group: String, + pub zone: String, pub location: Option<(F32<2>, F32<2>)>, pub webrtc: Option, pub rtmp: Option, @@ -54,6 +54,7 @@ pub struct QueryBestNodesRequest { #[derive(Debug, Serialize, Deserialize, Object, PartialEq, Eq, IntoVecU8, TryFromSliceU8, Clone)] pub struct QueryBestNodesResponse { pub nodes: Vec, + pub service_id: u8, } //TODO test this diff --git a/packages/cluster/src/implement/mod.rs b/packages/cluster/src/implement/mod.rs index a6cdf8e0..d789b761 100644 --- a/packages/cluster/src/implement/mod.rs +++ b/packages/cluster/src/implement/mod.rs @@ -26,7 +26,7 @@ mod tests { #[async_std::test] async fn subscribe_room() { - let (mut server, _rpc) = ServerSdn::new( + let (mut server, _rpc, _pubsub) = ServerSdn::new( 1, 0, 100, @@ -104,7 +104,7 @@ mod tests { #[async_std::test] async fn subscribe_peer() { - let (mut server, _rpc) = ServerSdn::new( + let (mut server, _rpc, _pubsub) = ServerSdn::new( 2, 0, 100, @@ -182,7 +182,7 @@ mod tests { #[async_std::test] async fn subscribe_stream() { - let (mut server, _rpc) = ServerSdn::new( + let (mut server, _rpc, _pubsub) = ServerSdn::new( 3, 0, 100, @@ -258,7 +258,7 @@ mod tests { #[async_std::test] async fn rpc() { - let (_server, mut rpc) = ServerSdn::new( + let (_server, mut rpc, _pubsub) = ServerSdn::new( 4, 0, 100, diff --git a/packages/cluster/src/implement/server.rs b/packages/cluster/src/implement/server.rs index ed7031c9..42caa6fa 100644 --- a/packages/cluster/src/implement/server.rs +++ b/packages/cluster/src/implement/server.rs @@ -49,7 +49,7 @@ pub struct ServerSdn { } impl ServerSdn { - pub async fn new(node_id: NodeId, port: u16, service_id: u8, config: ServerSdnConfig) -> (Self, RpcEndpointSdn) { + pub async fn new(node_id: NodeId, port: u16, service_id: u8, config: ServerSdnConfig) -> (Self, RpcEndpointSdn, PubsubSdk) { let mut node_addr_builder = NodeAddrBuilder::new(node_id); let udp_socket = UdpTransport::prepare(port, &mut node_addr_builder).await; let tcp_listener = TcpTransport::prepare(port, &mut node_addr_builder).await; @@ -103,12 +103,13 @@ impl ServerSdn { Self { node_id, node_addr: node_addr_builder.addr(), - pubsub_sdk, + pubsub_sdk: pubsub_sdk.clone(), kv_sdk, join_handler: Some(join_handler), rpc_emitter: rpc_box.emitter(), }, RpcEndpointSdn { rpc_box }, + pubsub_sdk, ) } } diff --git a/servers/media-server/Cargo.toml b/servers/media-server/Cargo.toml index 7230ad2d..0cbfc2a9 100644 --- a/servers/media-server/Cargo.toml +++ b/servers/media-server/Cargo.toml @@ -37,6 +37,7 @@ md5 = {version = "0.7.0", optional = true } rand = "0.8.5" yaque = { version = "0.6.6", optional = true } maxminddb = { version = "0.24.0", optional = true } +bincode = { version = "1" } [dev-dependencies] md5 = "0.7.0" diff --git a/servers/media-server/scripts/gateway.sh b/servers/media-server/scripts/gateway.sh new file mode 100755 index 00000000..7c106562 --- /dev/null +++ b/servers/media-server/scripts/gateway.sh @@ -0,0 +1,9 @@ +cargo run --package atm0s-media-server -- \ +--node-id 11 \ +--http-port 8011 \ +--sdn-port 10011 \ +--sdn-zone zone1 \ +gateway \ +--lat 37.7749 \ +--lng 122.4194 \ +--geoip-db ../../../maxminddb-data/GeoLite2-City.mmdb diff --git a/servers/media-server/scripts/gateway_global.sh b/servers/media-server/scripts/gateway_global.sh deleted file mode 100755 index 1334ad1d..00000000 --- a/servers/media-server/scripts/gateway_global.sh +++ /dev/null @@ -1,7 +0,0 @@ -cargo run --package atm0s-media-server -- \ ---node-id 1 \ ---http-port 8001 \ ---sdn-port 10001 \ -gateway \ ---mode global \ ---geoip-db ../../../maxminddb-data/GeoLite2-City.mmdb \ No newline at end of file diff --git a/servers/media-server/scripts/gateway_inner.sh b/servers/media-server/scripts/gateway_inner.sh deleted file mode 100755 index c6a3b0d0..00000000 --- a/servers/media-server/scripts/gateway_inner.sh +++ /dev/null @@ -1,11 +0,0 @@ -cargo run --package atm0s-media-server -- \ ---node-id 11 \ ---http-port 8011 \ ---sdn-port 10011 \ ---sdn-group group1 \ ---seeds 1@/ip4/127.0.0.1/udp/10001/ip4/127.0.0.1/tcp/10001 \ -gateway \ ---mode inner \ ---group local \ ---lat 37.7749 \ ---lng 122.4194 diff --git a/servers/media-server/scripts/gateway_other.sh b/servers/media-server/scripts/gateway_other.sh new file mode 100755 index 00000000..41b8373c --- /dev/null +++ b/servers/media-server/scripts/gateway_other.sh @@ -0,0 +1,10 @@ +cargo run --package atm0s-media-server -- \ +--node-id 12 \ +--http-port 8012 \ +--sdn-port 10012 \ +--sdn-zone zone2 \ +--seeds 11@/ip4/127.0.0.1/udp/10011/ip4/127.0.0.1/tcp/10011 \ +gateway \ +--lat 47.7749 \ +--lng 112.4194 \ +--geoip-db ../../../maxminddb-data/GeoLite2-City.mmdb diff --git a/servers/media-server/scripts/media_rtmp.sh b/servers/media-server/scripts/media_rtmp.sh index 28d2981c..4a3a4004 100755 --- a/servers/media-server/scripts/media_rtmp.sh +++ b/servers/media-server/scripts/media_rtmp.sh @@ -2,6 +2,6 @@ cargo run --package atm0s-media-server -- \ --node-id 21 \ --http-port 8021 \ --sdn-port 10021 \ ---sdn-group group1 \ +--sdn-zone zone1 \ --seeds 11@/ip4/127.0.0.1/udp/10011/ip4/127.0.0.1/tcp/10011 \ rtmp \ No newline at end of file diff --git a/servers/media-server/scripts/media_sip.sh b/servers/media-server/scripts/media_sip.sh index 6b2440ed..42900b10 100755 --- a/servers/media-server/scripts/media_sip.sh +++ b/servers/media-server/scripts/media_sip.sh @@ -2,6 +2,6 @@ cargo run --package atm0s-media-server -- \ --node-id 31 \ --http-port 8031 \ --sdn-port 10031 \ ---sdn-group group1 \ +--sdn-zone zone1 \ --seeds 11@/ip4/127.0.0.1/udp/10011/ip4/127.0.0.1/tcp/10011 \ sip --addr 127.0.0.1:5060 \ No newline at end of file diff --git a/servers/media-server/scripts/media_webrtc.sh b/servers/media-server/scripts/media_webrtc.sh index b7501c14..da781a20 100755 --- a/servers/media-server/scripts/media_webrtc.sh +++ b/servers/media-server/scripts/media_webrtc.sh @@ -2,6 +2,6 @@ cargo run --package atm0s-media-server -- \ --node-id 41 \ --http-port 8041 \ --sdn-port 10041 \ ---sdn-group group1 \ +--sdn-zone zone1 \ --seeds 11@/ip4/127.0.0.1/udp/10011/ip4/127.0.0.1/tcp/10011 \ webrtc \ No newline at end of file diff --git a/servers/media-server/src/main.rs b/servers/media-server/src/main.rs index 25ce090d..74e13fbd 100644 --- a/servers/media-server/src/main.rs +++ b/servers/media-server/src/main.rs @@ -8,7 +8,7 @@ mod server; use cluster::{ atm0s_sdn::SystemTimer, implement::{NodeAddr, NodeId, ServerSdn, ServerSdnConfig}, - CONNECTOR_SERVICE, GLOBAL_GATEWAY_SERVICE, INNER_GATEWAY_SERVICE, MEDIA_SERVER_SERVICE, + CONNECTOR_SERVICE, GATEWAY_SERVICE, MEDIA_SERVER_SERVICE, }; #[cfg(feature = "connector")] @@ -26,8 +26,6 @@ use server::webrtc::run_webrtc_server; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; -use crate::server::gateway::GatewayMode; - /// Media Server #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -44,9 +42,9 @@ struct Args { #[arg(env, long, default_value_t = 0)] sdn_port: u16, - /// Sdn group + /// Sdn Zone #[arg(env, long, default_value = "local")] - sdn_group: String, + sdn_zone: String, /// Current Node ID #[arg(env, long, default_value_t = 1)] @@ -105,63 +103,51 @@ async fn main() { #[cfg(feature = "gateway")] Servers::Gateway(opts) => { use server::MediaServerContext; - match opts.mode { - GatewayMode::Global => { - config.local_tags = vec!["gateway-global".to_string()]; - config.connect_tags = vec!["gateway-global".to_string()]; - } - GatewayMode::Inner => { - config.local_tags = vec![format!("gateway-inner-{}", args.sdn_group)]; - config.connect_tags = vec!["gateway-global".to_string(), format!("gateway-inner-{}", args.sdn_group)]; - } - } + config.local_tags = vec![format!("gateway-zone-{}", args.sdn_zone), "gateway".to_string()]; + config.connect_tags = vec!["gateway-global".to_string()]; let token = Arc::new(cluster::implement::jwt_static::JwtStaticToken::new(&args.secret)); let ctx = MediaServerContext::<()>::new(args.node_id, 0, Arc::new(SystemTimer()), token.clone(), token); - let rpc_service_id = match opts.mode { - GatewayMode::Inner => INNER_GATEWAY_SERVICE, - GatewayMode::Global => GLOBAL_GATEWAY_SERVICE, - }; - let (cluster, rpc_endpoint) = ServerSdn::new(args.node_id, args.sdn_port, rpc_service_id, config).await; - if let Err(e) = run_gateway_server(args.http_port, args.http_tls, opts, ctx, cluster, rpc_endpoint).await { + let (cluster, rpc_endpoint, pubsub) = ServerSdn::new(args.node_id, args.sdn_port, GATEWAY_SERVICE, config).await; + if let Err(e) = run_gateway_server(args.http_port, args.http_tls, &args.sdn_zone, opts, ctx, cluster, rpc_endpoint, pubsub).await { log::error!("[GatewayServer] error {}", e); } } #[cfg(feature = "webrtc")] Servers::Webrtc(opts) => { use server::MediaServerContext; - config.local_tags = vec![format!("media-webrtc-{}", args.sdn_group)]; - config.connect_tags = vec![format!("gateway-inner-{}", args.sdn_group)]; + config.local_tags = vec![format!("media-webrtc-{}", args.sdn_zone)]; + config.connect_tags = vec![format!("gateway-zone-{}", args.sdn_zone)]; let token = Arc::new(cluster::implement::jwt_static::JwtStaticToken::new(&args.secret)); let ctx = MediaServerContext::new(args.node_id, opts.max_conn, Arc::new(SystemTimer()), token.clone(), token); - let (cluster, rpc_endpoint) = ServerSdn::new(args.node_id, args.sdn_port, MEDIA_SERVER_SERVICE, config).await; - if let Err(e) = run_webrtc_server(args.http_port, args.http_tls, opts, ctx, cluster, rpc_endpoint).await { + let (cluster, rpc_endpoint, _pubsub) = ServerSdn::new(args.node_id, args.sdn_port, MEDIA_SERVER_SERVICE, config).await; + if let Err(e) = run_webrtc_server(args.http_port, args.http_tls, &args.sdn_zone, opts, ctx, cluster, rpc_endpoint).await { log::error!("[WebrtcServer] error {}", e); } } #[cfg(feature = "rtmp")] Servers::Rtmp(opts) => { use server::MediaServerContext; - config.local_tags = vec![format!("media-rtmp-{}", args.sdn_group)]; - config.connect_tags = vec![format!("gateway-inner-{}", args.sdn_group)]; + config.local_tags = vec![format!("media-rtmp-{}", args.sdn_zone)]; + config.connect_tags = vec![format!("gateway-zone-{}", args.sdn_zone)]; let token = Arc::new(cluster::implement::jwt_static::JwtStaticToken::new(&args.secret)); let ctx = MediaServerContext::new(args.node_id, opts.max_conn, Arc::new(SystemTimer()), token.clone(), token); - let (cluster, rpc_endpoint) = ServerSdn::new(args.node_id, args.sdn_port, MEDIA_SERVER_SERVICE, config).await; - if let Err(e) = run_rtmp_server(args.http_port, args.http_tls, opts, ctx, cluster, rpc_endpoint).await { + let (cluster, rpc_endpoint, _pubsub) = ServerSdn::new(args.node_id, args.sdn_port, MEDIA_SERVER_SERVICE, config).await; + if let Err(e) = run_rtmp_server(args.http_port, args.http_tls, &args.sdn_zone, opts, ctx, cluster, rpc_endpoint).await { log::error!("[RtmpServer] error {}", e); } } #[cfg(feature = "sip")] Servers::Sip(opts) => { use server::MediaServerContext; - config.local_tags = vec![format!("media-sip-{}", args.sdn_group)]; - config.connect_tags = vec![format!("gateway-inner-{}", args.sdn_group)]; + config.local_tags = vec![format!("media-sip-{}", args.sdn_zone)]; + config.connect_tags = vec![format!("gateway-zone-{}", args.sdn_zone)]; let token = Arc::new(cluster::implement::jwt_static::JwtStaticToken::new(&args.secret)); let ctx = MediaServerContext::new(args.node_id, opts.max_conn, Arc::new(SystemTimer()), token.clone(), token); - let (cluster, rpc_endpoint) = ServerSdn::new(args.node_id, args.sdn_port, MEDIA_SERVER_SERVICE, config).await; + let (cluster, rpc_endpoint, _pubsub) = ServerSdn::new(args.node_id, args.sdn_port, MEDIA_SERVER_SERVICE, config).await; if let Err(e) = run_sip_server(args.http_port, args.http_tls, opts, ctx, cluster, rpc_endpoint).await { log::error!("[RtmpServer] error {}", e); } @@ -169,12 +155,12 @@ async fn main() { #[cfg(feature = "connector")] Servers::Connector(opts) => { use server::MediaServerContext; - config.local_tags = vec![format!("connector-{}", args.sdn_group)]; - config.connect_tags = vec![format!("gateway-inner-{}", args.sdn_group)]; + config.local_tags = vec![format!("connector-{}", args.sdn_zone)]; + config.connect_tags = vec![format!("gateway-zone-{}", args.sdn_zone)]; let token = Arc::new(cluster::implement::jwt_static::JwtStaticToken::new(&args.secret)); let ctx = MediaServerContext::new(args.node_id, opts.max_conn, Arc::new(SystemTimer()), token.clone(), token); - let (cluster, rpc_endpoint) = ServerSdn::new(args.node_id, args.sdn_port, CONNECTOR_SERVICE, config).await; + let (cluster, rpc_endpoint, _pubsub) = ServerSdn::new(args.node_id, args.sdn_port, CONNECTOR_SERVICE, config).await; if let Err(e) = run_connector_server(args.http_port, args.http_tls, opts, ctx, cluster, rpc_endpoint).await { log::error!("[ConnectorServer] error {}", e); } diff --git a/servers/media-server/src/server/gateway.rs b/servers/media-server/src/server/gateway.rs index ce80ef79..5e616b34 100644 --- a/servers/media-server/src/server/gateway.rs +++ b/servers/media-server/src/server/gateway.rs @@ -3,23 +3,24 @@ use std::{sync::Arc, time::Duration}; use async_std::stream::StreamExt; use clap::Parser; use cluster::{ + atm0s_sdn::{Publisher, PubsubSdk}, implement::NodeId, rpc::{ - gateway::{NodePing, NodePong, QueryBestNodesResponse}, + gateway::{NodePing, QueryBestNodesResponse}, general::{MediaEndpointCloseRequest, MediaEndpointCloseResponse, MediaSessionProtocol, NodeInfo, ServerType}, webrtc::{WebrtcPatchRequest, WebrtcPatchResponse, WebrtcRemoteIceRequest, WebrtcRemoteIceResponse}, - RpcEmitter, RpcEndpoint, RpcRequest, RPC_MEDIA_ENDPOINT_CLOSE, RPC_NODE_PING, RPC_WEBRTC_CONNECT, RPC_WEBRTC_ICE, RPC_WEBRTC_PATCH, RPC_WHEP_CONNECT, RPC_WHIP_CONNECT, + RpcEmitter, RpcEndpoint, RpcRequest, RPC_MEDIA_ENDPOINT_CLOSE, RPC_WEBRTC_CONNECT, RPC_WEBRTC_ICE, RPC_WEBRTC_PATCH, RPC_WHEP_CONNECT, RPC_WHIP_CONNECT, }, - Cluster, ClusterEndpoint, GLOBAL_GATEWAY_SERVICE, INNER_GATEWAY_SERVICE, MEDIA_SERVER_SERVICE, + Cluster, ClusterEndpoint, MEDIA_SERVER_SERVICE, }; use futures::{select, FutureExt}; -use media_utils::{SystemTimer, Timer, F32}; +use media_utils::{hash_str, SystemTimer, Timer, F32}; use metrics::describe_counter; use metrics_dashboard::{build_dashboard_route, DashboardOptions}; use poem::{web::Json, Route}; use poem_openapi::OpenApiService; -use crate::rpc::http::HttpRpcServer; +use crate::{rpc::http::HttpRpcServer, server::gateway::logic::RouteResult}; #[cfg(feature = "embed-samples")] use crate::rpc::http::EmbeddedFilesEndpoint; @@ -39,7 +40,6 @@ use self::{ rpc::{cluster::GatewayClusterRpc, http::GatewayHttpApis, RpcEvent}, }; -pub use self::logic::GatewayMode; use super::MediaServerContext; mod ip2location; @@ -54,14 +54,6 @@ const GATEWAY_SESSIONS_CONNECT_ERROR: &str = "gateway.sessions.connect.error"; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct GatewayArgs { - /// Gateway mode - #[arg(value_enum, env, long, default_value_t = GatewayMode::Inner)] - pub mode: GatewayMode, - - /// Gateway group, only set if mode is Inner - #[arg(env, long, default_value = "")] - pub group: String, - /// lat location #[arg(env, long, default_value_t = 0.0)] pub lat: f32, @@ -75,7 +67,16 @@ pub struct GatewayArgs { pub geoip_db: String, } -pub async fn run_gateway_server(http_port: u16, http_tls: bool, opts: GatewayArgs, ctx: MediaServerContext<()>, cluster: C, rpc_endpoint: RPC) -> Result<(), &'static str> +pub async fn run_gateway_server( + http_port: u16, + http_tls: bool, + zone: &str, + opts: GatewayArgs, + ctx: MediaServerContext<()>, + cluster: C, + rpc_endpoint: RPC, + pubsub: PubsubSdk, +) -> Result<(), &'static str> where C: Cluster + Send + 'static, CR: ClusterEndpoint + Send + 'static, @@ -86,10 +87,7 @@ where let node_id = cluster.node_id(); let mut rpc_endpoint = GatewayClusterRpc::new(rpc_endpoint); let mut http_server: HttpRpcServer = crate::rpc::http::HttpRpcServer::new(http_port, http_tls); - let ip2location = match opts.mode { - GatewayMode::Global => Some(ip2location::Ip2Location::new(&opts.geoip_db)), - GatewayMode::Inner => None, - }; + let ip2location = ip2location::Ip2Location::new(&opts.geoip_db); let timer = Arc::new(SystemTimer()); let api_service = OpenApiService::new(GatewayHttpApis, "Gateway Server", env!("CARGO_PKG_VERSION")).server("/"); @@ -122,25 +120,49 @@ where http_server.start(route, ctx.clone()).await; let mut tick = async_std::stream::interval(Duration::from_millis(100)); - let mut gateway_logic = GatewayLogic::new(opts.mode); + let mut gateway_logic = GatewayLogic::new(zone); let rpc_emitter = rpc_endpoint.emitter(); let mut gateway_feedback_tick = async_std::stream::interval(Duration::from_millis(2000)); - let dest_service_id = match opts.mode { - GatewayMode::Global => INNER_GATEWAY_SERVICE, - GatewayMode::Inner => MEDIA_SERVER_SERVICE, - }; + let gateway_zone_channel_pub = pubsub.create_publisher(hash_str(&format!("gateway-zone-{}", zone)) as u32); + let gateway_zone_channel = pubsub.create_consumer(hash_str(&format!("gateway-zone-{}", zone)) as u32, None); + let gateway_channel_pub = pubsub.create_publisher(hash_str("gateway") as u32); + let gateway_channel = pubsub.create_consumer(hash_str("gateway") as u32, None); loop { let rpc = select! { _ = tick.next().fuse() => { gateway_logic.on_tick(timer.now_ms()); continue; - } - _ = gateway_feedback_tick.next().fuse() => { - if matches!(opts.mode, GatewayMode::Inner) { - ping_global_gateway(&gateway_logic, &opts.group, (F32::<2>::new(opts.lat), F32::<2>::new(opts.lng)), node_id, &rpc_emitter); + }, + e = gateway_zone_channel.recv().fuse() => match e { + Some((_, from, _, data)) => { + if from == node_id { + continue; + } + if let Ok(ping) = bincode::deserialize(&data) { + log::info!("[Gateway] node ping {:?}", ping); + gateway_logic.on_node_ping(timer.now_ms(), &ping); + } + continue; + }, + None => { + continue; } - + }, + e = gateway_channel.recv().fuse() => match e { + Some((_, _, _, data)) => { + if let Ok(ping) = bincode::deserialize(&data) { + log::info!("[Gateway] gateway ping {:?}", ping); + gateway_logic.on_gateway_ping(timer.now_ms(), &ping); + } + continue; + }, + None => { + continue; + } + }, + _ = gateway_feedback_tick.next().fuse() => { + ping_other_gateways(&gateway_logic, zone, (F32::<2>::new(opts.lat), F32::<2>::new(opts.lng)), node_id, &gateway_channel_pub); continue; }, rpc = http_server.recv().fuse() => { @@ -154,12 +176,13 @@ where match rpc { RpcEvent::NodePing(req) => { log::info!("[Gateway] node ping {:?}", req.param()); - req.answer(Ok(gateway_logic.on_ping(timer.now_ms(), req.param()))); + gateway_zone_channel_pub.send(bincode::serialize(req.param()).expect("Should serialize").into()); + req.answer(Ok(gateway_logic.on_node_ping(timer.now_ms(), req.param()))); } - RpcEvent::BestNodest(req) => { + RpcEvent::BestNodes(req) => { log::info!("[Gateway] best nodes {:?}", req.param()); - let nodes = gateway_logic.best_nodes( - ip2location.as_ref().map(|f| f.get_location(&req.param().ip_addr)).flatten(), + let route_res = gateway_logic.best_nodes( + ip2location.get_location(&req.param().ip_addr), match req.param().protocol { MediaSessionProtocol::Rtmp => ServiceType::Rtmp, MediaSessionProtocol::Sip => ServiceType::Sip, @@ -171,11 +194,15 @@ where 80, req.param().size, ); - req.answer(Ok(QueryBestNodesResponse { nodes })); + if let RouteResult::OtherNode { nodes, service_id } = route_res { + req.answer(Ok(QueryBestNodesResponse { nodes, service_id })); + } else { + req.answer(Err("NOT_FOUND")); + } } RpcEvent::WhipConnect(req) => { log::info!("[Gateway] whip connect compressed_sdp: {:?}", req.param().compressed_sdp.as_ref().map(|sdp| sdp.len())); - let location = ip2location.as_ref().map(|f| f.get_location(&req.param().ip_addr)).flatten(); + let location = ip2location.get_location(&req.param().ip_addr); webrtc_route::route_to_node( rpc_emitter.clone(), timer.clone(), @@ -189,12 +216,11 @@ where &req.param().user_agent.clone(), req.param().session_uuid, req, - dest_service_id, ); } RpcEvent::WhepConnect(req) => { log::info!("[Gateway] whep connect compressed_sdp: {:?}", req.param().compressed_sdp.as_ref().map(|sdp| sdp.len())); - let location = ip2location.as_ref().map(|f| f.get_location(&req.param().ip_addr)).flatten(); + let location = ip2location.get_location(&req.param().ip_addr); webrtc_route::route_to_node( rpc_emitter.clone(), timer.clone(), @@ -208,12 +234,11 @@ where &req.param().user_agent.clone(), req.param().session_uuid, req, - dest_service_id, ); } RpcEvent::WebrtcConnect(req) => { log::info!("[Gateway] webrtc connect compressed_sdp: {:?}", req.param().compressed_sdp.as_ref().map(|sdp| sdp.len())); - let location = ip2location.as_ref().map(|f| f.get_location(&req.param().ip_addr)).flatten(); + let location = ip2location.get_location(&req.param().ip_addr); webrtc_route::route_to_node( rpc_emitter.clone(), timer.clone(), @@ -227,7 +252,6 @@ where &req.param().user_agent.clone(), req.param().session_uuid, req, - dest_service_id, ); } RpcEvent::WebrtcRemoteIce(req) => { @@ -273,21 +297,16 @@ where } } -fn ping_global_gateway(logic: &GatewayLogic, group: &str, location: (F32<2>, F32<2>), node_id: NodeId, rpc_emitter: &EMITTER) { +fn ping_other_gateways(logic: &GatewayLogic, zone: &str, location: (F32<2>, F32<2>), node_id: NodeId, publisher: &Publisher) { let stats = logic.stats(); let req = NodePing { node_id, - group: group.to_string(), + zone: zone.to_string(), location: Some(location), rtmp: stats.rtmp, sip: stats.sip, webrtc: stats.webrtc, }; - let rpc_emitter = rpc_emitter.clone(); - async_std::task::spawn(async move { - if let Err(e) = rpc_emitter.request::<_, NodePong>(GLOBAL_GATEWAY_SERVICE, None, RPC_NODE_PING, req, 1000).await { - log::error!("[Gateway] ping global gateway error {:?}", e); - } - }); + publisher.send(bincode::serialize(&req).expect("Should serialize").into()); } diff --git a/servers/media-server/src/server/gateway/logic.rs b/servers/media-server/src/server/gateway/logic.rs index e91b6445..6c34d01e 100644 --- a/servers/media-server/src/server/gateway/logic.rs +++ b/servers/media-server/src/server/gateway/logic.rs @@ -1,28 +1,30 @@ use std::collections::HashMap; -use clap::ValueEnum; use cluster::{ implement::NodeId, rpc::gateway::{NodePing, NodePong, ServiceInfo}, }; use media_utils::F32; +use self::{global_registry::ServiceGlobalRegistry, inner_registry::ServiceInnerRegistry}; + mod global_registry; mod inner_registry; +#[derive(Debug, PartialEq, Eq)] +pub enum RouteResult { + NotFound, + LocalNode, + OtherNode { nodes: Vec, service_id: u8 }, +} + trait ServiceRegistry { fn on_tick(&mut self, now_ms: u64); - fn on_ping(&mut self, now_ms: u64, group: &str, location: Option<(F32<2>, F32<2>)>, node_id: NodeId, usage: u8, live: u32, max: u32); - fn best_nodes(&mut self, location: Option<(F32<2>, F32<2>)>, max_usage: u8, max_usage_fallback: u8, size: usize) -> Vec; + fn on_ping(&mut self, now_ms: u64, zone: &str, location: Option<(F32<2>, F32<2>)>, node_id: NodeId, usage: u8, live: u32, max: u32); + fn best_nodes(&mut self, location: Option<(F32<2>, F32<2>)>, max_usage: u8, max_usage_fallback: u8, size: usize) -> RouteResult; fn stats(&self) -> ServiceInfo; } -#[derive(ValueEnum, Clone, Copy, Debug)] -pub enum GatewayMode { - Global, - Inner, -} - /// Represents the type of service. #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] pub enum ServiceType { @@ -39,19 +41,28 @@ pub struct GatewayStats { /// Represents the gateway logic for handling node pings and managing services. pub struct GatewayLogic { - mode: GatewayMode, - services: HashMap>, + zone: String, + global_gateways: HashMap, + inner_services: HashMap, } impl GatewayLogic { /// Creates a new instance of `GatewayLogic`. - pub fn new(mode: GatewayMode) -> Self { - Self { mode, services: Default::default() } + pub fn new(zone: &str) -> Self { + Self { + zone: zone.to_string(), + global_gateways: Default::default(), + inner_services: Default::default(), + } } /// Handles the tick event. pub fn on_tick(&mut self, now_ms: u64) { - for (_typ, service) in self.services.iter_mut() { + for (_typ, service) in self.inner_services.iter_mut() { + service.on_tick(now_ms); + } + + for (_typ, service) in self.global_gateways.iter_mut() { service.on_tick(now_ms); } } @@ -66,21 +77,42 @@ impl GatewayLogic { /// # Returns /// /// A `NodePong` struct with a success flag indicating the success of the ping operation. - pub fn on_ping(&mut self, now_ms: u64, ping: &NodePing) -> NodePong { + pub fn on_node_ping(&mut self, now_ms: u64, ping: &NodePing) -> NodePong { if let Some(meta) = &ping.webrtc { - self.on_node_ping_service(now_ms, ping.node_id, ServiceType::Webrtc, &ping.group, ping.location, meta.usage, meta.live, meta.max); + self.on_node_ping_service(now_ms, ping.node_id, ServiceType::Webrtc, &ping.zone, ping.location, meta.usage, meta.live, meta.max); } if let Some(meta) = &ping.rtmp { - self.on_node_ping_service(now_ms, ping.node_id, ServiceType::Rtmp, &ping.group, ping.location, meta.usage, meta.live, meta.max); + self.on_node_ping_service(now_ms, ping.node_id, ServiceType::Rtmp, &ping.zone, ping.location, meta.usage, meta.live, meta.max); } if let Some(meta) = &ping.sip { - self.on_node_ping_service(now_ms, ping.node_id, ServiceType::Sip, &ping.group, ping.location, meta.usage, meta.live, meta.max); + self.on_node_ping_service(now_ms, ping.node_id, ServiceType::Sip, &ping.zone, ping.location, meta.usage, meta.live, meta.max); } NodePong { success: true } } + /// Handles the ping event for a gateway. + /// + /// # Arguments + /// + /// * `now_ms` - The current timestamp in milliseconds. + /// * `ping` - A reference to a `NodePing` struct containing information about the ping. + /// + pub fn on_gateway_ping(&mut self, now_ms: u64, ping: &NodePing) { + if let Some(meta) = &ping.webrtc { + self.on_gateway_ping_service(now_ms, ping.node_id, ServiceType::Webrtc, &ping.zone, ping.location, meta.usage, meta.live, meta.max); + } + if let Some(meta) = &ping.rtmp { + self.on_gateway_ping_service(now_ms, ping.node_id, ServiceType::Rtmp, &ping.zone, ping.location, meta.usage, meta.live, meta.max); + } + if let Some(meta) = &ping.sip { + self.on_gateway_ping_service(now_ms, ping.node_id, ServiceType::Sip, &ping.zone, ping.location, meta.usage, meta.live, meta.max); + } + } + /// Returns the best nodes for a service. /// + /// First we will check if we need to route to other gateway nodes, if not we will check in local + /// /// # Arguments /// /// * `service` - The type of service. @@ -91,11 +123,20 @@ impl GatewayLogic { /// # Returns /// /// A vector of `NodeId` representing the best nodes for the service. - pub fn best_nodes(&mut self, location: Option<(F32<2>, F32<2>)>, service: ServiceType, max_usage: u8, max_usage_fallback: u8, size: usize) -> Vec { - self.services + pub fn best_nodes(&mut self, location: Option<(F32<2>, F32<2>)>, service: ServiceType, max_usage: u8, max_usage_fallback: u8, size: usize) -> RouteResult { + if let Some(service) = self.global_gateways.get_mut(&service) { + match service.best_nodes(location, max_usage, max_usage_fallback, size) { + RouteResult::OtherNode { nodes, service_id } => { + return RouteResult::OtherNode { nodes, service_id }; + } + _ => {} + } + } + + self.inner_services .get_mut(&service) .map(|s| s.best_nodes(location, max_usage, max_usage_fallback, size)) - .unwrap_or_else(|| vec![]) + .unwrap_or_else(|| RouteResult::NotFound) } /// Handles the ping event for a specific service of a node. @@ -107,12 +148,27 @@ impl GatewayLogic { /// * `service` - The type of service. /// * `usage` - The usage value. /// * `max` - The maximum value. - fn on_node_ping_service(&mut self, now_ms: u64, node_id: NodeId, service: ServiceType, group: &str, location: Option<(F32<2>, F32<2>)>, usage: u8, live: u32, max: u32) { - let service = self.services.entry(service).or_insert_with(|| match self.mode { - GatewayMode::Global => Box::new(global_registry::ServiceGlobalRegistry::new(service)), - GatewayMode::Inner => Box::new(inner_registry::ServiceInnerRegistry::new(service)), - }); - service.on_ping(now_ms, group, location, node_id, usage, live, max); + fn on_node_ping_service(&mut self, now_ms: u64, node_id: NodeId, service: ServiceType, zone: &str, location: Option<(F32<2>, F32<2>)>, usage: u8, live: u32, max: u32) { + if self.zone.eq(zone) { + let service = self.inner_services.entry(service).or_insert_with(|| ServiceInnerRegistry::new(service)); + service.on_ping(now_ms, zone, location, node_id, usage, live, max); + } else { + log::warn!("ping from wrong zone {} vs current zone {}", zone, self.zone); + } + } + + /// Handles the ping event for a specific service of a gateway. + /// + /// # Arguments + /// + /// * `now_ms` - The current timestamp in milliseconds. + /// * `node_id` - The ID of the node. + /// * `service` - The type of service. + /// * `usage` - The usage value. + /// * `max` - The maximum value. + fn on_gateway_ping_service(&mut self, now_ms: u64, node_id: NodeId, service: ServiceType, zone: &str, location: Option<(F32<2>, F32<2>)>, usage: u8, live: u32, max: u32) { + let service = self.global_gateways.entry(service).or_insert_with(|| ServiceGlobalRegistry::new(&self.zone, service)); + service.on_ping(now_ms, zone, location, node_id, usage, live, max); } /// Returns the statistics for the gateway server. @@ -125,7 +181,7 @@ impl GatewayLogic { let sip = None; let mut webrtc = None; - for (service, registry) in &self.services { + for (service, registry) in &self.inner_services { match service { ServiceType::Webrtc => webrtc = Some(registry.stats()), // ServiceType::Rtmp => rtmp = Some(registry.stats()), //TODO support rtmp @@ -142,26 +198,27 @@ impl GatewayLogic { mod tests { use cluster::rpc::gateway::{NodePing, ServiceInfo}; - use crate::server::gateway::logic::{GatewayLogic, GatewayMode}; + use crate::server::gateway::logic::GatewayLogic; #[test] fn test_gateway_logic_creation() { - let gateway_logic = GatewayLogic::new(GatewayMode::Inner); - assert_eq!(gateway_logic.services.len(), 0); + let gateway_logic = GatewayLogic::new("zone1"); + assert_eq!(gateway_logic.inner_services.len(), 0); + assert_eq!(gateway_logic.global_gateways.len(), 0); } #[test] fn test_on_tick_without_services() { - let mut gateway_logic = GatewayLogic::new(GatewayMode::Inner); + let mut gateway_logic = GatewayLogic::new("zone1"); gateway_logic.on_tick(0); } #[test] fn test_on_ping_with_valid_node_ping() { - let mut gateway_logic = GatewayLogic::new(GatewayMode::Inner); + let mut gateway_logic = GatewayLogic::new("zone1"); let node_ping = NodePing { node_id: 1, - group: "".to_string(), + zone: "zone1".to_string(), location: None, webrtc: Some(ServiceInfo { usage: 50, @@ -179,24 +236,24 @@ mod tests { }), sip: None, }; - let node_pong = gateway_logic.on_ping(0, &node_ping); + let node_pong = gateway_logic.on_node_ping(0, &node_ping); assert_eq!(node_pong.success, true); - assert_eq!(gateway_logic.services.len(), 2); + assert_eq!(gateway_logic.inner_services.len(), 2); } #[test] fn test_on_ping_with_no_services() { - let mut gateway_logic = GatewayLogic::new(GatewayMode::Inner); + let mut gateway_logic = GatewayLogic::new("zone1"); let node_ping = NodePing { node_id: 1, - group: "".to_string(), + zone: "zone1".to_string(), location: None, webrtc: None, rtmp: None, sip: None, }; - let node_pong = gateway_logic.on_ping(0, &node_ping); + let node_pong = gateway_logic.on_node_ping(0, &node_ping); assert_eq!(node_pong.success, true); } } diff --git a/servers/media-server/src/server/gateway/logic/global_registry.rs b/servers/media-server/src/server/gateway/logic/global_registry.rs index 0fbae6b4..319efd99 100644 --- a/servers/media-server/src/server/gateway/logic/global_registry.rs +++ b/servers/media-server/src/server/gateway/logic/global_registry.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; -use cluster::{implement::NodeId, rpc::gateway::ServiceInfo}; +use cluster::{implement::NodeId, rpc::gateway::ServiceInfo, GATEWAY_SERVICE}; use media_utils::F32; use metrics::{describe_gauge, gauge}; -use super::{ServiceRegistry, ServiceType}; +use super::{RouteResult, ServiceRegistry, ServiceType}; const NODE_TIMEOUT_MS: u64 = 10_000; @@ -24,6 +24,7 @@ fn lat_lng_distance(from: &(F32<2>, F32<2>), to: &(F32<2>, F32<2>)) -> f32 { #[derive(Debug, Default)] struct Zone { + zone: String, location: (F32<2>, F32<2>), nodes: HashMap, usage: u8, @@ -34,19 +35,21 @@ struct Zone { #[derive(Debug)] pub(super) struct ServiceGlobalRegistry { + zone: String, metric_live: String, metric_max: String, zones: HashMap, } impl ServiceGlobalRegistry { - pub fn new(service: ServiceType) -> Self { + pub fn new(zone: &str, service: ServiceType) -> Self { let metric_live = format!("gateway.sessions.{:?}.live", service); let metric_max = format!("gateway.sessions.{:?}.max", service); describe_gauge!(metric_live.clone(), format!("Current live {:?} sessions number", service)); describe_gauge!(metric_max.clone(), format!("Max live {:?} sessions number", service)); Self { + zone: zone.to_string(), metric_live, metric_max, zones: Default::default(), @@ -87,10 +90,10 @@ impl ServiceRegistry for ServiceGlobalRegistry { } /// we save node or create new, then sort by ascending order - fn on_ping(&mut self, now_ms: u64, group: &str, location: Option<(F32<2>, F32<2>)>, node_id: NodeId, usage: u8, live: u32, max: u32) { + fn on_ping(&mut self, now_ms: u64, zone: &str, location: Option<(F32<2>, F32<2>)>, node_id: NodeId, usage: u8, live: u32, max: u32) { let location = location.unwrap_or((F32::<2>::new(0.0), F32::<2>::new(0.0))); - if let Some(slot) = self.zones.get_mut(group) { + if let Some(slot) = self.zones.get_mut(zone) { slot.nodes.insert(node_id, now_ms); slot.usage = usage; slot.live = live; @@ -98,8 +101,9 @@ impl ServiceRegistry for ServiceGlobalRegistry { slot.last_updated = now_ms; } else { self.zones.insert( - group.to_string(), + zone.to_string(), Zone { + zone: zone.to_string(), nodes: HashMap::from([(node_id, now_ms)]), location, usage, @@ -112,7 +116,7 @@ impl ServiceRegistry for ServiceGlobalRegistry { } /// we get first with max_usage, if not enough => using max_usage_fallback - fn best_nodes(&mut self, location: Option<(F32<2>, F32<2>)>, max_usage: u8, max_usage_fallback: u8, size: usize) -> Vec { + fn best_nodes(&mut self, location: Option<(F32<2>, F32<2>)>, max_usage: u8, max_usage_fallback: u8, size: usize) -> RouteResult { let location = location.unwrap_or((F32::<2>::new(0.0), F32::<2>::new(0.0))); //finding closest zone @@ -122,11 +126,15 @@ impl ServiceRegistry for ServiceGlobalRegistry { } if let Some(zone) = closest_zone { - let mut nodes = zone.nodes.keys().cloned().collect::>(); - nodes.truncate(size); - nodes + if zone.zone.eq(&self.zone) { + RouteResult::LocalNode + } else { + let mut nodes = zone.nodes.keys().cloned().collect::>(); + nodes.truncate(size); + RouteResult::OtherNode { nodes, service_id: GATEWAY_SERVICE } + } } else { - vec![] + RouteResult::NotFound } } @@ -143,32 +151,33 @@ mod tests { // ServiceGlobalRegistry can be created with default values #[test] fn test_service_registry_creation() { - let registry = ServiceGlobalRegistry::new(ServiceType::Webrtc); + let mut registry = ServiceGlobalRegistry::new("local", ServiceType::Webrtc); assert_eq!(registry.zones.len(), 0); + assert_eq!(registry.best_nodes(None, 80, 90, 1), RouteResult::NotFound); } // test with single zone and single gateway #[test] fn test_service_registry_single_zone_single_gateway() { - let mut registry = ServiceGlobalRegistry::new(ServiceType::Webrtc); + let mut registry = ServiceGlobalRegistry::new("zone1", ServiceType::Webrtc); let now_ms = 0; - let group = "group"; + let zone = "zone1"; let location = Some((F32::<2>::new(0.0), F32::<2>::new(0.0))); let node_id = 1; let usage = 0; let live = 0; let max = 10; - registry.on_ping(now_ms, group, location, node_id, usage, live, max); + registry.on_ping(now_ms, zone, location, node_id, usage, live, max); assert_eq!(registry.zones.len(), 1); - assert_eq!(registry.zones.get(group).unwrap().nodes.len(), 1); - assert_eq!(registry.zones.get(group).unwrap().nodes.get(&node_id).unwrap(), &now_ms); - assert_eq!(registry.zones.get(group).unwrap().usage, usage); - assert_eq!(registry.zones.get(group).unwrap().live, live); - assert_eq!(registry.zones.get(group).unwrap().max, max); - assert_eq!(registry.zones.get(group).unwrap().last_updated, now_ms); + assert_eq!(registry.zones.get(zone).unwrap().nodes.len(), 1); + assert_eq!(registry.zones.get(zone).unwrap().nodes.get(&node_id).unwrap(), &now_ms); + assert_eq!(registry.zones.get(zone).unwrap().usage, usage); + assert_eq!(registry.zones.get(zone).unwrap().live, live); + assert_eq!(registry.zones.get(zone).unwrap().max, max); + assert_eq!(registry.zones.get(zone).unwrap().last_updated, now_ms); - assert_eq!(registry.best_nodes(location, 60, 80, 1), vec![node_id]); + assert_eq!(registry.best_nodes(location, 60, 80, 1), RouteResult::LocalNode); registry.on_tick(now_ms + NODE_TIMEOUT_MS); assert_eq!(registry.zones.len(), 0); @@ -176,42 +185,42 @@ mod tests { #[test] fn test_service_fallback_max_usage() { - let mut registry = ServiceGlobalRegistry::new(ServiceType::Webrtc); + let mut registry = ServiceGlobalRegistry::new("zone1", ServiceType::Webrtc); let now_ms = 0; - let group = "group"; + let zone = "zone1"; let location = Some((F32::<2>::new(0.0), F32::<2>::new(0.0))); let node_id = 1; let usage = 70; let live = 9; let max = 10; - registry.on_ping(now_ms, group, location, node_id, usage, live, max); - assert_eq!(registry.best_nodes(location, 50, 60, 1), Vec::::new()); - assert_eq!(registry.best_nodes(location, 60, 80, 1), vec![node_id]); + registry.on_ping(now_ms, zone, location, node_id, usage, live, max); + assert_eq!(registry.best_nodes(location, 50, 60, 1), RouteResult::NotFound); + assert_eq!(registry.best_nodes(location, 60, 80, 1), RouteResult::LocalNode); } // test with gateway with max zero should return none #[test] fn test_service_registry_single_zone_single_gateway_with_max_zero() { - let mut registry = ServiceGlobalRegistry::new(ServiceType::Webrtc); + let mut registry = ServiceGlobalRegistry::new("zone1", ServiceType::Webrtc); let now_ms = 0; - let group = "group"; + let zone = "zone1"; let location = Some((F32::<2>::new(0.0), F32::<2>::new(0.0))); let node_id = 1; let usage = 0; let live = 0; let max = 0; - registry.on_ping(now_ms, group, location, node_id, usage, live, max); + registry.on_ping(now_ms, zone, location, node_id, usage, live, max); assert_eq!(registry.zones.len(), 1); - assert_eq!(registry.zones.get(group).unwrap().nodes.len(), 1); - assert_eq!(registry.zones.get(group).unwrap().nodes.get(&node_id).unwrap(), &now_ms); - assert_eq!(registry.zones.get(group).unwrap().usage, usage); - assert_eq!(registry.zones.get(group).unwrap().live, live); - assert_eq!(registry.zones.get(group).unwrap().max, max); - assert_eq!(registry.zones.get(group).unwrap().last_updated, now_ms); + assert_eq!(registry.zones.get(zone).unwrap().nodes.len(), 1); + assert_eq!(registry.zones.get(zone).unwrap().nodes.get(&node_id).unwrap(), &now_ms); + assert_eq!(registry.zones.get(zone).unwrap().usage, usage); + assert_eq!(registry.zones.get(zone).unwrap().live, live); + assert_eq!(registry.zones.get(zone).unwrap().max, max); + assert_eq!(registry.zones.get(zone).unwrap().last_updated, now_ms); - assert_eq!(registry.best_nodes(location, 60, 80, 1), Vec::::new()); + assert_eq!(registry.best_nodes(location, 60, 80, 1), RouteResult::NotFound); registry.on_tick(now_ms + NODE_TIMEOUT_MS); assert_eq!(registry.zones.len(), 0); @@ -220,9 +229,9 @@ mod tests { // test with single zone multi gateways #[test] fn test_service_registry_single_zone_multi_gateways() { - let mut registry = ServiceGlobalRegistry::new(ServiceType::Webrtc); + let mut registry = ServiceGlobalRegistry::new("zone1", ServiceType::Webrtc); let now_ms = 0; - let group = "group"; + let zone = "zone1"; let location = Some((F32::<2>::new(0.0), F32::<2>::new(0.0))); let node_id_1 = 1; let node_id_2 = 2; @@ -230,41 +239,40 @@ mod tests { let live = 0; let max = 10; - registry.on_ping(now_ms, group, location, node_id_1, usage, live, max); - registry.on_ping(now_ms, group, location, node_id_2, usage, live, max); + registry.on_ping(now_ms, zone, location, node_id_1, usage, live, max); + registry.on_ping(now_ms, zone, location, node_id_2, usage, live, max); assert_eq!(registry.zones.len(), 1); - assert_eq!(registry.zones.get(group).unwrap().nodes.len(), 2); - assert_eq!(registry.zones.get(group).unwrap().nodes.get(&node_id_1).unwrap(), &now_ms); - assert_eq!(registry.zones.get(group).unwrap().nodes.get(&node_id_2).unwrap(), &now_ms); - assert_eq!(registry.zones.get(group).unwrap().usage, usage); - assert_eq!(registry.zones.get(group).unwrap().live, live); - assert_eq!(registry.zones.get(group).unwrap().max, max); - assert_eq!(registry.zones.get(group).unwrap().last_updated, now_ms); + assert_eq!(registry.zones.get(zone).unwrap().nodes.len(), 2); + assert_eq!(registry.zones.get(zone).unwrap().nodes.get(&node_id_1).unwrap(), &now_ms); + assert_eq!(registry.zones.get(zone).unwrap().nodes.get(&node_id_2).unwrap(), &now_ms); + assert_eq!(registry.zones.get(zone).unwrap().usage, usage); + assert_eq!(registry.zones.get(zone).unwrap().live, live); + assert_eq!(registry.zones.get(zone).unwrap().max, max); + assert_eq!(registry.zones.get(zone).unwrap().last_updated, now_ms); - let mut best_nodes = registry.best_nodes(location, 60, 80, 2); - best_nodes.sort(); - assert_eq!(best_nodes, vec![node_id_1, node_id_2]); + let route_res = registry.best_nodes(location, 60, 80, 2); + assert_eq!(route_res, RouteResult::LocalNode); - registry.on_ping(1000, group, location, node_id_1, usage, live, max); + registry.on_ping(1000, zone, location, node_id_1, usage, live, max); //simulate timeout registry.on_tick(now_ms + NODE_TIMEOUT_MS); assert_eq!(registry.zones.len(), 1); - assert_eq!(registry.zones.get(group).unwrap().nodes.len(), 1); - assert_eq!(registry.zones.get(group).unwrap().nodes.get(&node_id_1).unwrap(), &1000); - assert_eq!(registry.zones.get(group).unwrap().usage, usage); - assert_eq!(registry.zones.get(group).unwrap().live, live); - assert_eq!(registry.zones.get(group).unwrap().max, max); - assert_eq!(registry.zones.get(group).unwrap().last_updated, 1000); + assert_eq!(registry.zones.get(zone).unwrap().nodes.len(), 1); + assert_eq!(registry.zones.get(zone).unwrap().nodes.get(&node_id_1).unwrap(), &1000); + assert_eq!(registry.zones.get(zone).unwrap().usage, usage); + assert_eq!(registry.zones.get(zone).unwrap().live, live); + assert_eq!(registry.zones.get(zone).unwrap().max, max); + assert_eq!(registry.zones.get(zone).unwrap().last_updated, 1000); } //test with multi zones and multi gateways #[test] fn test_service_registry_multi_zones_multi_gateways() { - let mut registry = ServiceGlobalRegistry::new(ServiceType::Webrtc); + let mut registry = ServiceGlobalRegistry::new("zone1", ServiceType::Webrtc); let now_ms = 0; - let group_1 = "group_1"; - let group_2 = "group_2"; + let zone_1 = "zone1"; + let zone_2 = "zone2"; let location_1 = Some((F32::<2>::new(0.0), F32::<2>::new(0.0))); let location_2 = Some((F32::<2>::new(1.0), F32::<2>::new(1.0))); let node_id_1 = 1; @@ -273,31 +281,35 @@ mod tests { let live = 0; let max = 10; - registry.on_ping(now_ms, group_1, location_1, node_id_1, usage, live, max); - registry.on_ping(now_ms, group_2, location_2, node_id_2, usage, live, max); + registry.on_ping(now_ms, zone_1, location_1, node_id_1, usage, live, max); + registry.on_ping(now_ms, zone_2, location_2, node_id_2, usage, live, max); assert_eq!(registry.zones.len(), 2); - assert_eq!(registry.zones.get(group_1).unwrap().nodes.len(), 1); - assert_eq!(registry.zones.get(group_1).unwrap().nodes.get(&node_id_1).unwrap(), &now_ms); - assert_eq!(registry.zones.get(group_1).unwrap().usage, usage); - assert_eq!(registry.zones.get(group_1).unwrap().live, live); - assert_eq!(registry.zones.get(group_1).unwrap().max, max); - assert_eq!(registry.zones.get(group_1).unwrap().last_updated, now_ms); - - assert_eq!(registry.zones.get(group_2).unwrap().nodes.len(), 1); - assert_eq!(registry.zones.get(group_2).unwrap().nodes.get(&node_id_2).unwrap(), &now_ms); - assert_eq!(registry.zones.get(group_2).unwrap().usage, usage); - assert_eq!(registry.zones.get(group_2).unwrap().live, live); - assert_eq!(registry.zones.get(group_2).unwrap().max, max); - assert_eq!(registry.zones.get(group_2).unwrap().last_updated, now_ms); - - let mut best_nodes = registry.best_nodes(location_1, 60, 80, 2); - best_nodes.sort(); - assert_eq!(best_nodes, vec![node_id_1]); - - let mut best_nodes = registry.best_nodes(location_2, 60, 80, 2); - best_nodes.sort(); - assert_eq!(best_nodes, vec![node_id_2]); + assert_eq!(registry.zones.get(zone_1).unwrap().nodes.len(), 1); + assert_eq!(registry.zones.get(zone_1).unwrap().nodes.get(&node_id_1).unwrap(), &now_ms); + assert_eq!(registry.zones.get(zone_1).unwrap().usage, usage); + assert_eq!(registry.zones.get(zone_1).unwrap().live, live); + assert_eq!(registry.zones.get(zone_1).unwrap().max, max); + assert_eq!(registry.zones.get(zone_1).unwrap().last_updated, now_ms); + + assert_eq!(registry.zones.get(zone_2).unwrap().nodes.len(), 1); + assert_eq!(registry.zones.get(zone_2).unwrap().nodes.get(&node_id_2).unwrap(), &now_ms); + assert_eq!(registry.zones.get(zone_2).unwrap().usage, usage); + assert_eq!(registry.zones.get(zone_2).unwrap().live, live); + assert_eq!(registry.zones.get(zone_2).unwrap().max, max); + assert_eq!(registry.zones.get(zone_2).unwrap().last_updated, now_ms); + + let route_res = registry.best_nodes(location_1, 60, 80, 2); + assert_eq!(route_res, RouteResult::LocalNode); + + let route_res = registry.best_nodes(location_2, 60, 80, 2); + assert_eq!( + route_res, + RouteResult::OtherNode { + nodes: vec![node_id_2], + service_id: GATEWAY_SERVICE + } + ); } #[test] diff --git a/servers/media-server/src/server/gateway/logic/inner_registry.rs b/servers/media-server/src/server/gateway/logic/inner_registry.rs index 460d34bf..b71b1c65 100644 --- a/servers/media-server/src/server/gateway/logic/inner_registry.rs +++ b/servers/media-server/src/server/gateway/logic/inner_registry.rs @@ -1,10 +1,10 @@ use std::cmp::Ordering; -use cluster::{implement::NodeId, rpc::gateway::ServiceInfo}; +use cluster::{implement::NodeId, rpc::gateway::ServiceInfo, MEDIA_SERVER_SERVICE}; use media_utils::F32; use metrics::{describe_gauge, gauge}; -use super::{ServiceRegistry, ServiceType}; +use super::{RouteResult, ServiceRegistry, ServiceType}; const NODE_TIMEOUT_MS: u64 = 10_000; @@ -79,7 +79,7 @@ impl ServiceRegistry for ServiceInnerRegistry { } /// we save node or create new, then sort by ascending order - fn on_ping(&mut self, now_ms: u64, _group: &str, _location: Option<(F32<2>, F32<2>)>, node_id: NodeId, usage: u8, live: u32, max: u32) { + fn on_ping(&mut self, now_ms: u64, _zone: &str, _location: Option<(F32<2>, F32<2>)>, node_id: NodeId, usage: u8, live: u32, max: u32) { if let Some(slot) = self.nodes.iter_mut().find(|s| s.node_id == node_id) { slot.usage = usage; slot.live = live; @@ -98,23 +98,23 @@ impl ServiceRegistry for ServiceInnerRegistry { } /// we get first with max_usage, if not enough => using max_usage_fallback - fn best_nodes(&mut self, _location: Option<(F32<2>, F32<2>)>, max_usage: u8, max_usage_fallback: u8, size: usize) -> Vec { - let mut res = vec![]; + fn best_nodes(&mut self, _location: Option<(F32<2>, F32<2>)>, max_usage: u8, max_usage_fallback: u8, size: usize) -> RouteResult { + let mut nodes = vec![]; for slot in self.nodes.iter().rev() { if slot.usage <= max_usage { - res.push(slot.node_id); - if res.len() == size { + nodes.push(slot.node_id); + if nodes.len() == size { break; } } } - if res.len() < size { + if nodes.len() < size { for slot in self.nodes.iter().rev() { if slot.usage <= max_usage_fallback { - if !res.contains(&slot.node_id) { - res.push(slot.node_id); - if res.len() == size { + if !nodes.contains(&slot.node_id) { + nodes.push(slot.node_id); + if nodes.len() == size { break; } } @@ -122,7 +122,14 @@ impl ServiceRegistry for ServiceInnerRegistry { } } - res + if nodes.is_empty() { + RouteResult::NotFound + } else { + RouteResult::OtherNode { + nodes, + service_id: MEDIA_SERVER_SERVICE, + } + } } fn stats(&self) -> ServiceInfo { @@ -261,8 +268,13 @@ mod tests { let result = registry.best_nodes(None, max_usage, max_usage_fallback, size); - assert_eq!(result.len(), 1); - assert_eq!(result[0], node_id1); + assert_eq!( + result, + RouteResult::OtherNode { + nodes: vec![node_id1], + service_id: MEDIA_SERVER_SERVICE + } + ); } #[test] @@ -286,10 +298,17 @@ mod tests { let size = 2; let mut result = registry.best_nodes(None, max_usage_fallback, max_usage, size); + if let RouteResult::OtherNode { nodes, service_id: _ } = &mut result { + nodes.sort(); + } - assert_eq!(result.len(), 2); - result.sort(); - assert_eq!(result, [node_id1, node_id2]); + assert_eq!( + result, + RouteResult::OtherNode { + nodes: vec![node_id1, node_id2], + service_id: MEDIA_SERVER_SERVICE + } + ); } #[test] diff --git a/servers/media-server/src/server/gateway/rpc.rs b/servers/media-server/src/server/gateway/rpc.rs index 5503cf34..f6c78c71 100644 --- a/servers/media-server/src/server/gateway/rpc.rs +++ b/servers/media-server/src/server/gateway/rpc.rs @@ -6,7 +6,7 @@ pub mod http; pub enum RpcEvent { NodePing(Box>), - BestNodest(Box>), + BestNodes(Box>), WhipConnect(Box>), WhepConnect(Box>), WebrtcConnect(Box>), diff --git a/servers/media-server/src/server/gateway/rpc/http.rs b/servers/media-server/src/server/gateway/rpc/http.rs index a5156be5..2ff8d971 100644 --- a/servers/media-server/src/server/gateway/rpc/http.rs +++ b/servers/media-server/src/server/gateway/rpc/http.rs @@ -59,7 +59,7 @@ impl GatewayHttpApis { size: size.0, }); data.0 - .send(RpcEvent::BestNodest(Box::new(req))) + .send(RpcEvent::BestNodes(Box::new(req))) .await .map_err(|_e| poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))?; let res = rx.recv().await.map_err(|e| poem::Error::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; diff --git a/servers/media-server/src/server/gateway/webrtc_route.rs b/servers/media-server/src/server/gateway/webrtc_route.rs index 87220500..9ac0b9a3 100644 --- a/servers/media-server/src/server/gateway/webrtc_route.rs +++ b/servers/media-server/src/server/gateway/webrtc_route.rs @@ -19,7 +19,7 @@ use protocol::media_event_logs::{ use crate::server::gateway::{GATEWAY_SESSIONS_CONNECT_COUNT, GATEWAY_SESSIONS_CONNECT_ERROR}; -use super::logic::{GatewayLogic, ServiceType}; +use super::logic::{GatewayLogic, RouteResult, ServiceType}; async fn select_node(emitter: &EMITTER, node_ids: &[u32], service_id: u8) -> Option { let mut futures = Vec::new(); @@ -78,7 +78,7 @@ fn emit_endpoint_event(emitter: &EMITTER, 1000, ) .await - .log_error("Should ok"); + .log_error("Should send media-log-event to connector"); }); } @@ -95,7 +95,6 @@ pub fn route_to_node( user_agent: &str, session_uuid: u64, req: Box>, - dest_service_id: u8, ) where EMITTER: RpcEmitter + Send + Sync + 'static, Req: Into> + Send + Clone + 'static, @@ -109,18 +108,18 @@ pub fn route_to_node( }); emit_endpoint_event(&emitter, &timer, session_uuid, &ip.to_string(), version, event); - let nodes = gateway_logic.best_nodes(location, service, 60, 80, 3); - if !nodes.is_empty() { + let route_res = gateway_logic.best_nodes(location, service, 60, 80, 3); + if let RouteResult::OtherNode { nodes, service_id } = route_res { let rpc_emitter = emitter.clone(); let ip: String = ip.to_string(); let version = version.clone(); let param = req.param().clone(); async_std::task::spawn(async move { log::info!("[Gateway] connect => ping nodes {:?}", nodes); - let node_id = select_node(&rpc_emitter, &nodes, dest_service_id).await; + let node_id = select_node(&rpc_emitter, &nodes, service_id).await; if let Some(node_id) = node_id { log::info!("[Gateway] connect with selected node {:?}", node_id); - let res = rpc_emitter.request::(dest_service_id, Some(node_id), cmd, param, 5000).await; + let res = rpc_emitter.request::(service_id, Some(node_id), cmd, param, 5000).await; log::info!("[Gateway] webrtc connect res from media-server {:?}", res.as_ref().map(|_| ())); let event = if res.is_err() { counter!(GATEWAY_SESSIONS_CONNECT_ERROR).increment(1); diff --git a/servers/media-server/src/server/rtmp.rs b/servers/media-server/src/server/rtmp.rs index 0f7481c8..aafc452d 100644 --- a/servers/media-server/src/server/rtmp.rs +++ b/servers/media-server/src/server/rtmp.rs @@ -12,7 +12,7 @@ use cluster::{ general::{MediaEndpointCloseResponse, MediaSessionProtocol, NodeInfo, ServerType}, RpcEmitter, RpcEndpoint, RpcRequest, RPC_NODE_PING, }, - Cluster, ClusterEndpoint, INNER_GATEWAY_SERVICE, + Cluster, ClusterEndpoint, GATEWAY_SERVICE, }; use futures::{select, FutureExt}; use media_utils::ErrorDebugger; @@ -64,6 +64,7 @@ pub struct RtmpArgs { pub async fn run_rtmp_server( http_port: u16, http_tls: bool, + zone: &str, opts: RtmpArgs, ctx: MediaServerContext, mut cluster: C, @@ -119,7 +120,7 @@ where loop { let rpc = select! { _ = gateway_feedback_tick.next().fuse() => { - ping_gateway(&ctx, node_id, rtmp_port, &rpc_emitter); + ping_gateway(&ctx, node_id, zone, rtmp_port, &rpc_emitter); continue; }, rpc = http_server.recv().fuse() => { @@ -188,10 +189,10 @@ where } } -fn ping_gateway(ctx: &MediaServerContext, node_id: NodeId, rtmp_port: u16, rpc_emitter: &EMITTER) { +fn ping_gateway(ctx: &MediaServerContext, node_id: NodeId, zone: &str, rtmp_port: u16, rpc_emitter: &EMITTER) { let req = NodePing { node_id, - group: "".to_string(), + zone: zone.to_string(), location: None, rtmp: Some(ServiceInfo { usage: ((ctx.conns_live() * 100) / ctx.conns_max()) as u8, @@ -206,7 +207,7 @@ fn ping_gateway(ctx: &MediaServerContext(INNER_GATEWAY_SERVICE, None, RPC_NODE_PING, req, 1000).await { + if let Err(e) = rpc_emitter.request::<_, NodePong>(GATEWAY_SERVICE, None, RPC_NODE_PING, req, 1000).await { log::error!("[RtmpServer] ping gateway error {:?}", e); } }); diff --git a/servers/media-server/src/server/webrtc.rs b/servers/media-server/src/server/webrtc.rs index 71f14b3a..830d7c1c 100644 --- a/servers/media-server/src/server/webrtc.rs +++ b/servers/media-server/src/server/webrtc.rs @@ -12,7 +12,7 @@ use cluster::{ whip::WhipConnectResponse, RpcEmitter, RpcEndpoint, RpcReqRes, RpcRequest, RPC_NODE_PING, }, - BitrateControlMode, Cluster, ClusterEndpoint, ClusterEndpointPublishScope, ClusterEndpointSubscribeScope, MixMinusAudioMode, VerifyObject, INNER_GATEWAY_SERVICE, + BitrateControlMode, Cluster, ClusterEndpoint, ClusterEndpointPublishScope, ClusterEndpointSubscribeScope, MixMinusAudioMode, VerifyObject, GATEWAY_SERVICE, }; use futures::{select, FutureExt}; use media_utils::{ErrorDebugger, StringCompression, SystemTimer, Timer}; @@ -68,6 +68,7 @@ pub struct WebrtcArgs { pub async fn run_webrtc_server( http_port: u16, http_tls: bool, + zone: &str, _opts: WebrtcArgs, ctx: MediaServerContext, mut cluster: C, @@ -122,7 +123,7 @@ where loop { let rpc = select! { _ = gateway_feedback_tick.next().fuse() => { - ping_gateway(&ctx, node_id, &rpc_emitter); + ping_gateway(&ctx, node_id, zone, &rpc_emitter); continue; }, rpc = http_server.recv().fuse() => { @@ -387,10 +388,10 @@ where } } -fn ping_gateway(ctx: &MediaServerContext, node_id: NodeId, rpc_emitter: &EMITTER) { +fn ping_gateway(ctx: &MediaServerContext, node_id: NodeId, zone: &str, rpc_emitter: &EMITTER) { let req = NodePing { node_id, - group: "".to_string(), + zone: zone.to_string(), location: None, rtmp: None, sip: None, @@ -405,7 +406,7 @@ fn ping_gateway(ctx: &MediaServerContext(INNER_GATEWAY_SERVICE, None, RPC_NODE_PING, req, 1000).await { + if let Err(e) = rpc_emitter.request::<_, NodePong>(GATEWAY_SERVICE, None, RPC_NODE_PING, req, 1000).await { log::error!("[WebrtcServer] ping gateway error {:?}", e); } });