diff --git a/Cargo.lock b/Cargo.lock index f1d6f111..13eb9c5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1181,6 +1181,7 @@ dependencies = [ "media-server-utils", "num_enum 0.7.2", "sans-io-runtime", + "smallmap", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d2fd8f7e..bc4ef8e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,4 @@ atm0s-sdn = { path = "../atm0s-sdn/packages/runner" } convert-enum = "0.1" num_enum = "0.7" log = "0.4" +smallmap = "1.4" diff --git a/bin/src/http.rs b/bin/src/http.rs index f86bc16b..52397372 100644 --- a/bin/src/http.rs +++ b/bin/src/http.rs @@ -1,3 +1,5 @@ +use std::net::SocketAddr; + use media_server_protocol::endpoint::ClusterConnId; use media_server_protocol::transport::{RpcReq, RpcRes}; use poem::endpoint::StaticFilesEndpoint; @@ -51,7 +53,7 @@ pub async fn run_gateway_http_server(sender: Sender, R Ok(()) } -pub async fn run_media_http_server(sender: Sender, RpcRes>>) -> Result<(), Box> { +pub async fn run_media_http_server(port: u16, sender: Sender, RpcRes>>) -> Result<(), Box> { let api_service: OpenApiService<_, ()> = OpenApiService::new(api_media::MediaApis, "Media Server APIs", env!("CARGO_PKG_VERSION")).server("/"); let ui = api_service.swagger_ui(); let spec = api_service.spec(); @@ -63,6 +65,6 @@ pub async fn run_media_http_server(sender: Sender, Rpc .with(Cors::new()) .data(api_media::MediaServerCtx { sender }); - Server::new(TcpListener::bind("0.0.0.0:3000")).run(route).await?; + Server::new(TcpListener::bind(SocketAddr::new([0, 0, 0, 0].into(), port))).run(route).await?; Ok(()) } diff --git a/bin/src/http/api_media.rs b/bin/src/http/api_media.rs index 66e4e3f2..1642acbd 100644 --- a/bin/src/http/api_media.rs +++ b/bin/src/http/api_media.rs @@ -1,6 +1,7 @@ use media_server_protocol::{ endpoint::ClusterConnId, transport::{ + whep::{self, WhepConnectReq, WhepDeleteReq, WhepRemoteIceReq}, whip::{self, WhipConnectReq, WhipDeleteReq, WhipRemoteIceReq}, RpcReq, RpcRes, RpcResult, }, @@ -122,4 +123,96 @@ impl MediaApis { _ => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)), } } + + /// connect whep endpoint + #[oai(path = "/whep/endpoint", method = "post")] + async fn whep_create( + &self, + Data(data): Data<&MediaServerCtx>, + UserAgent(user_agent): UserAgent, + RemoteIpAddr(ip_addr): RemoteIpAddr, + TokenAuthorization(token): TokenAuthorization, + body: ApplicationSdp, + ) -> Result>> { + log::info!("[MediaAPIs] create whep endpoint with token {}, ip {}, user_agent {}", token.token, ip_addr, user_agent); + let (req, rx) = Rpc::new(RpcReq::Whep(whep::RpcReq::Connect(WhepConnectReq { + ip: ip_addr, + sdp: body.0, + token: token.token, + user_agent, + }))); + data.sender.send(req).await.map_err(|_e| poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))?; + let res = rx.await.map_err(|_e| poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))?; + match res { + RpcRes::Whep(whep::RpcRes::Connect(res)) => match res { + RpcResult::Ok(res) => { + log::info!("[HttpApis] Whep endpoint created with conn_id {}", res.conn_id); + Ok(CustomHttpResponse { + code: StatusCode::CREATED, + res: ApplicationSdp(res.sdp), + headers: vec![("location", format!("/whep/conn/{}", res.conn_id))], + }) + } + RpcResult::Err(e) => { + log::warn!("Whep endpoint creation failed with {e}"); + Err(poem::Error::from_string(e.to_string(), StatusCode::BAD_REQUEST)) + } + }, + _ => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)), + } + } + + /// patch whep conn for trickle-ice + #[oai(path = "/whep/conn/:conn_id", method = "patch")] + async fn conn_whep_patch(&self, Data(data): Data<&MediaServerCtx>, conn_id: Path, body: ApplicationSdpPatch) -> Result>> { + let conn_id = conn_id.0.parse().map_err(|_e| poem::Error::from_status(StatusCode::BAD_REQUEST))?; + log::info!("[HttpApis] patch whep endpoint with sdp {}", body.0); + let (req, rx) = Rpc::new(RpcReq::Whep(whep::RpcReq::RemoteIce(WhepRemoteIceReq { conn_id, ice: body.0 }))); + data.sender.send(req).await.map_err(|_e| poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))?; + let res = rx.await.map_err(|_e| poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))?; + //TODO process with ICE restart + match res { + RpcRes::Whep(whep::RpcRes::RemoteIce(res)) => match res { + RpcResult::Ok(_res) => { + log::info!("[HttpApis] Whep endpoint patch trickle-ice with conn_id {conn_id}"); + Ok(HttpResponse::new(ApplicationSdpPatch("".to_string())).status(StatusCode::NO_CONTENT)) + } + RpcResult::Err(e) => { + log::warn!("Whep endpoint patch trickle-ice failed with error {e}"); + Err(poem::Error::from_string(e.to_string(), StatusCode::BAD_REQUEST)) + } + }, + _ => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)), + } + } + + /// post whep conn for action + #[oai(path = "/api/whep/conn/:conn_id", method = "post")] + async fn conn_whep_post(&self, _ctx: Data<&MediaServerCtx>, _conn_id: Path, _body: Json) -> Result> { + // let conn_id = conn_id.0.parse().map_err(|_e| poem::Error::from_status(StatusCode::BAD_REQUEST))?; + Err(poem::Error::from_string("Not supported", StatusCode::BAD_REQUEST)) + } + + /// delete whep conn + #[oai(path = "/whep/conn/:conn_id", method = "delete")] + async fn conn_whep_delete(&self, Data(data): Data<&MediaServerCtx>, conn_id: Path) -> Result> { + let conn_id = conn_id.0.parse().map_err(|_e| poem::Error::from_status(StatusCode::BAD_REQUEST))?; + log::info!("[HttpApis] close whep endpoint conn {}", conn_id); + let (req, rx) = Rpc::new(RpcReq::Whep(whep::RpcReq::Delete(WhepDeleteReq { conn_id }))); + data.sender.send(req).await.map_err(|_e| poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))?; + let res = rx.await.map_err(|_e| poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))?; + match res { + RpcRes::Whep(whep::RpcRes::Delete(res)) => match res { + RpcResult::Ok(_res) => { + log::info!("[HttpApis] Whep endpoint closed with conn_id {conn_id}"); + Ok(PlainText("OK".to_string())) + } + RpcResult::Err(e) => { + log::warn!("Whep endpoint close request failed with error {e}"); + Err(poem::Error::from_string(e.to_string(), StatusCode::BAD_REQUEST)) + } + }, + _ => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)), + } + } } diff --git a/bin/src/main.rs b/bin/src/main.rs index 44d6339d..73742963 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -48,22 +48,26 @@ struct Args { #[arg(env, long)] seeds: Vec, - /// Neighbors - #[arg(env, long)] + /// Workers + #[arg(env, long, default_value_t = 1)] workers: usize, #[command(subcommand)] server: ServerType, } -#[tokio::main] +#[tokio::main(flavor = "current_thread")] async fn main() { if std::env::var_os("RUST_LOG").is_none() { - std::env::set_var("RUST_LOG", "atm0s_media_server=info"); + std::env::set_var("RUST_LOG", "info"); + } + if std::env::var_os("RUST_BACKTRACE").is_none() { + std::env::set_var("RUST_BACKTRACE", "1"); } let args: Args = Args::parse(); tracing_subscriber::registry().with(fmt::layer()).with(EnvFilter::from_default_env()).init(); + let http_port = args.http_port; let workers = args.workers; let node = NodeConfig { node_id: args.node_id, @@ -76,6 +80,6 @@ async fn main() { match args.server { ServerType::Gateway(args) => run_media_gateway(workers, args).await, ServerType::Connector(args) => run_media_connector(workers, args).await, - ServerType::Media(args) => run_media_server(workers, node, args).await, + ServerType::Media(args) => run_media_server(workers, http_port, node, args).await, } } diff --git a/bin/src/server/media.rs b/bin/src/server/media.rs index 71088573..97cc8ca2 100644 --- a/bin/src/server/media.rs +++ b/bin/src/server/media.rs @@ -12,20 +12,25 @@ use runtime_worker::{ExtIn, ExtOut}; #[derive(Debug, Parser)] pub struct Args { + /// Custom binding address for WebRTC UDP + #[arg(env, long)] custom_addrs: Vec, } -pub async fn run_media_server(workers: usize, node: NodeConfig, args: Args) { +pub async fn run_media_server(workers: usize, http_port: Option, node: NodeConfig, args: Args) { println!("Running media server"); let (req_tx, mut req_rx) = tokio::sync::mpsc::channel(1024); - tokio::spawn(async move { - if let Err(e) = run_media_http_server(req_tx).await { - log::error!("HTTP Error: {}", e); - } - }); + if let Some(http_port) = http_port { + tokio::spawn(async move { + if let Err(e) = run_media_http_server(http_port, req_tx).await { + log::error!("HTTP Error: {}", e); + } + }); + } //TODO get local addrs let node_id = node.node_id; + let node_session = node.session; let webrtc_addrs = args.custom_addrs; let mut controller = Controller::<_, _, _, _, _, 128>::default(); for i in 0..workers { @@ -49,13 +54,19 @@ pub async fn run_media_server(workers: usize, node: NodeConfig, args: Args) { req_id_seed += 1; reqs.insert(req_id, req.answer_tx); - let (req, _node_id) = req.req.extract(); - let (req, worker) = req.extract(); + let (req, _node_id) = req.req.down(); + let (req, worker) = req.down(); let ext = ExtIn::Rpc(req_id, req); if let Some(worker) = worker { - controller.send_to(worker, ext); + if worker < workers as u16 { + log::info!("on req {req_id} dest to worker {worker}"); + controller.send_to(worker, ext); + } else { + log::info!("on req {req_id} dest to wrong worker {worker} but workers is {workers}"); + } } else { + log::info!("on req {req_id} dest to any worker"); controller.send_to_best(ext); } } @@ -63,7 +74,8 @@ pub async fn run_media_server(workers: usize, node: NodeConfig, args: Args) { while let Some(out) = controller.pop_event() { match out { ExtOut::Rpc(req_id, worker, res) => { - let res = res.up_layer(worker).up_layer(node_id); + log::info!("on req {req_id} res from worker {worker}"); + let res = res.up(worker).up((node_id, node_session)); if let Some(tx) = reqs.remove(&req_id) { if let Err(_) = tx.send(res) { log::error!("Send rpc response error for req {req_id}"); diff --git a/bin/src/server/media/runtime_worker.rs b/bin/src/server/media/runtime_worker.rs index 22b8bb25..7027ba62 100644 --- a/bin/src/server/media/runtime_worker.rs +++ b/bin/src/server/media/runtime_worker.rs @@ -94,6 +94,9 @@ impl WorkerInner for MediaRunt } fn on_tick<'a>(&mut self, now: Instant) -> Option> { + if !self.queue.is_empty() { + return self.queue.pop_front(); + } let out = self.worker.on_tick(now)?; Some(self.process_out(out)) } @@ -104,6 +107,9 @@ impl WorkerInner for MediaRunt } fn pop_output<'a>(&mut self, now: Instant) -> Option> { + if !self.queue.is_empty() { + return self.queue.pop_front(); + } let out = self.worker.pop_output(now)?; Some(self.process_out(out)) } diff --git a/packages/media_core/Cargo.toml b/packages/media_core/Cargo.toml index 0e05edb0..ee3e3ece 100644 --- a/packages/media_core/Cargo.toml +++ b/packages/media_core/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] log = { workspace = true } num_enum = { workspace = true } +smallmap = { workspace = true } sans-io-runtime = { workspace = true, default-features = false } atm0s-sdn = { workspace = true } media-server-protocol = { path = "../protocol" } diff --git a/packages/media_core/src/cluster.rs b/packages/media_core/src/cluster.rs index 23732a4a..7284e30d 100644 --- a/packages/media_core/src/cluster.rs +++ b/packages/media_core/src/cluster.rs @@ -78,18 +78,22 @@ impl Default for MediaCluster { impl MediaCluster { pub fn on_tick(&mut self, now: Instant) -> Option> { - todo!() + //TODO + None } pub fn on_input(&mut self, now: Instant, input: Input) -> Option> { - todo!() + //TODO + None } pub fn pop_output(&mut self, now: Instant) -> Option> { - todo!() + //TODO + None } pub fn shutdown<'a>(&mut self, now: Instant) -> Option> { - todo!() + //TODO + None } } diff --git a/packages/media_core/src/endpoint.rs b/packages/media_core/src/endpoint.rs index 0670b777..44f0b599 100644 --- a/packages/media_core/src/endpoint.rs +++ b/packages/media_core/src/endpoint.rs @@ -189,6 +189,7 @@ impl, ExtIn, ExtOut> Endpoint { let out = self.internal.on_transport_rpc(now, req_id, req)?; self.process_internal_output(now, out) } + TransportOutput::Destroy => Some(EndpointOutput::Shutdown), } } diff --git a/packages/media_core/src/endpoint/internal.rs b/packages/media_core/src/endpoint/internal.rs index ba0b9b14..e79cbbde 100644 --- a/packages/media_core/src/endpoint/internal.rs +++ b/packages/media_core/src/endpoint/internal.rs @@ -1,4 +1,7 @@ -use std::time::Instant; +use std::{ + collections::{HashMap, VecDeque}, + time::Instant, +}; use media_server_protocol::endpoint::{PeerId, RoomId}; @@ -7,8 +10,13 @@ use crate::{ transport::{LocalTrackEvent, LocalTrackId, RemoteTrackEvent, RemoteTrackId, TransportEvent, TransportInput, TransportState, TransportStats}, }; +use self::{local_track::EndpointLocalTrack, remote_track::EndpointRemoteTrack}; + use super::{middleware::EndpointMiddleware, EndpointEvent, EndpointReq, EndpointReqId, EndpointRes}; +mod local_track; +mod remote_track; + pub enum InternalOutput { Event(EndpointEvent), RpcRes(EndpointReqId, EndpointRes), @@ -16,13 +24,28 @@ pub enum InternalOutput { } pub struct EndpointInternal { + state: TransportState, room: Option<(RoomId, PeerId)>, + local_tracks_id: Vec, + remote_tracks_id: Vec, + local_tracks: smallmap::Map, + remote_tracks: smallmap::Map, middlewares: Vec>, + queue: VecDeque, } impl EndpointInternal { pub fn new() -> Self { - Self { room: None, middlewares: Vec::new() } + Self { + state: TransportState::Connecting, + room: None, + local_tracks_id: Default::default(), + remote_tracks_id: Default::default(), + local_tracks: Default::default(), + remote_tracks: Default::default(), + middlewares: Default::default(), + queue: Default::default(), + } } pub fn on_tick<'a>(&mut self, now: Instant) -> Option { @@ -30,7 +53,7 @@ impl EndpointInternal { } pub fn pop_output<'a>(&mut self, now: Instant) -> Option { - None + self.queue.pop_front() } } @@ -46,55 +69,80 @@ impl EndpointInternal { } pub fn on_transport_rpc<'a>(&mut self, now: Instant, req_id: EndpointReqId, req: EndpointReq) -> Option { - todo!() + match req { + EndpointReq::JoinRoom(room, peer) => { + self.room = Some((room.clone(), peer.clone())); + if matches!(self.state, TransportState::Connecting) { + Some(InternalOutput::Cluster(ClusterEndpointControl::JoinRoom(room, peer))) + } else { + None + } + } + EndpointReq::LeaveRoom => { + self.room.take()?; + if matches!(self.state, TransportState::Connecting) { + Some(InternalOutput::Cluster(ClusterEndpointControl::LeaveRoom)) + } else { + None + } + } + EndpointReq::RemoteTrack(track, control) => todo!(), + EndpointReq::LocalTrack(_, _) => todo!(), + } } fn on_transport_state_changed<'a>(&mut self, now: Instant, state: TransportState) -> Option { - match state { - TransportState::Connecting => todo!(), - TransportState::ConnectError(_) => todo!(), - TransportState::Connected => todo!(), - TransportState::Reconnecting => todo!(), - TransportState::Disconnected(_) => todo!(), + self.state = state; + match &self.state { + TransportState::Connecting => None, + TransportState::ConnectError(_) => None, + TransportState::Connected => { + for i in 0..self.local_tracks_id.len() { + let id = self.local_tracks_id[i]; + if let Some(out) = self.local_tracks.get_mut(&id).expect("Should have").on_connected(now) { + if let Some(out) = self.on_local_track_output(now, id, out) { + self.queue.push_back(out); + } + } + } + for i in 0..self.remote_tracks_id.len() { + let id = self.remote_tracks_id[i]; + let track = self.remote_tracks.get_mut(&id).expect("Should have"); + if let Some(out) = track.on_connected(now) { + if let Some(out) = self.on_remote_track_output(now, id, out) { + self.queue.push_back(out); + } + } + } + let (room, peer) = self.room.as_ref()?; + self.queue.push_back(InternalOutput::Cluster(ClusterEndpointControl::JoinRoom(room.clone(), peer.clone()))); + self.queue.pop_front() + } + TransportState::Reconnecting => None, + TransportState::Disconnected(_) => None, } } fn on_transport_remote_track<'a>(&mut self, now: Instant, track: RemoteTrackId, event: RemoteTrackEvent) -> Option { - match event { - RemoteTrackEvent::Started { name } => todo!(), - RemoteTrackEvent::Paused => todo!(), - RemoteTrackEvent::Resumed => todo!(), - RemoteTrackEvent::Media(_) => todo!(), - RemoteTrackEvent::Ended => todo!(), + if event.need_create() { + self.remote_tracks_id.push(track); + self.remote_tracks.insert(track, EndpointRemoteTrack::default()); } + let out = self.remote_tracks.get_mut(&track)?.on_event(now, event)?; + self.on_remote_track_output(now, track, out) } fn on_transport_local_track<'a>(&mut self, now: Instant, track: LocalTrackId, event: LocalTrackEvent) -> Option { - match event { - LocalTrackEvent::Started => todo!(), - LocalTrackEvent::RequestKeyFrame => todo!(), - LocalTrackEvent::Switch(_) => todo!(), - LocalTrackEvent::Ended => todo!(), - } - } - - fn on_transport_req<'a>(&mut self, now: Instant, req_id: EndpointReqId, req: EndpointReq) -> Option { - match req { - EndpointReq::JoinRoom(room, peer) => { - self.room = Some((room.clone(), peer.clone())); - Some(InternalOutput::Cluster(ClusterEndpointControl::JoinRoom(room, peer))) - } - EndpointReq::LeaveRoom => { - self.room.take()?; - Some(InternalOutput::Cluster(ClusterEndpointControl::LeaveRoom)) - } - EndpointReq::RemoteTrack(track, control) => todo!(), - EndpointReq::LocalTrack(_, _) => todo!(), + if event.need_create() { + self.local_tracks_id.push(track); + self.local_tracks.insert(track, EndpointLocalTrack::default()); } + let out = self.local_tracks.get_mut(&track)?.on_event(now, event)?; + self.on_local_track_output(now, track, out) } fn on_transport_stats<'a>(&mut self, now: Instant, stats: TransportStats) -> Option { - todo!() + None } } @@ -123,3 +171,14 @@ impl EndpointInternal { } } } + +/// This block for internal local and remote track +impl EndpointInternal { + fn on_remote_track_output<'a>(&mut self, now: Instant, id: RemoteTrackId, out: remote_track::Output) -> Option { + todo!() + } + + fn on_local_track_output<'a>(&mut self, now: Instant, id: LocalTrackId, out: local_track::Output) -> Option { + todo!() + } +} diff --git a/packages/media_core/src/endpoint/internal/local_track.rs b/packages/media_core/src/endpoint/internal/local_track.rs new file mode 100644 index 00000000..4f4bed99 --- /dev/null +++ b/packages/media_core/src/endpoint/internal/local_track.rs @@ -0,0 +1,20 @@ +use std::time::Instant; + +use crate::transport::{LocalTrackEvent, LocalTrackId}; + +pub enum Output {} + +#[derive(Default)] +pub struct EndpointLocalTrack {} + +impl EndpointLocalTrack { + pub fn on_connected(&mut self, now: Instant) -> Option { + None + } + pub fn on_event(&mut self, now: Instant, event: LocalTrackEvent) -> Option { + None + } + pub fn pop_output(&mut self) -> Option { + None + } +} diff --git a/packages/media_core/src/endpoint/internal/remote_track.rs b/packages/media_core/src/endpoint/internal/remote_track.rs new file mode 100644 index 00000000..701d2801 --- /dev/null +++ b/packages/media_core/src/endpoint/internal/remote_track.rs @@ -0,0 +1,20 @@ +use std::time::Instant; + +use crate::transport::{RemoteTrackEvent, RemoteTrackId}; + +pub enum Output {} + +#[derive(Default)] +pub struct EndpointRemoteTrack {} + +impl EndpointRemoteTrack { + pub fn on_connected(&mut self, now: Instant) -> Option { + None + } + pub fn on_event(&mut self, now: Instant, event: RemoteTrackEvent) -> Option { + None + } + pub fn pop_output(&mut self) -> Option { + None + } +} diff --git a/packages/media_core/src/transport.rs b/packages/media_core/src/transport.rs index db507e87..e494d4df 100644 --- a/packages/media_core/src/transport.rs +++ b/packages/media_core/src/transport.rs @@ -1,4 +1,4 @@ -use std::time::Instant; +use std::{hash::Hash, time::Instant}; use media_server_protocol::{ endpoint::{PeerId, TrackName}, @@ -16,10 +16,22 @@ pub struct TransportId(pub u64); #[derive(Clone, Copy, PartialEq, Eq)] pub struct RemoteTrackId(pub u16); +impl Hash for RemoteTrackId { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + /// LocalTrackId is used for track which send media to client #[derive(Clone, Copy, PartialEq, Eq)] pub struct LocalTrackId(pub u16); +impl Hash for LocalTrackId { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + pub enum TransportError { Timeout, } @@ -47,6 +59,12 @@ pub enum LocalTrackEvent { Ended, } +impl LocalTrackEvent { + pub fn need_create(&self) -> bool { + matches!(self, LocalTrackEvent::Started { .. }) + } +} + /// This is used for notifying state of remote track to endpoint pub enum RemoteTrackEvent { Started { name: String }, @@ -56,6 +74,12 @@ pub enum RemoteTrackEvent { Ended, } +impl RemoteTrackEvent { + pub fn need_create(&self) -> bool { + matches!(self, RemoteTrackEvent::Started { .. }) + } +} + pub enum TransportEvent { State(TransportState), RemoteTrack(RemoteTrackId, RemoteTrackEvent), @@ -78,6 +102,7 @@ pub enum TransportOutput<'a, Ext> { Event(TransportEvent), RpcReq(EndpointReqId, EndpointReq), Ext(Ext), + Destroy, } pub trait Transport { diff --git a/packages/media_runner/src/worker.rs b/packages/media_runner/src/worker.rs index 4eeb97c4..dac4b866 100644 --- a/packages/media_runner/src/worker.rs +++ b/packages/media_runner/src/worker.rs @@ -3,6 +3,7 @@ use std::{collections::VecDeque, net::SocketAddr, time::Instant}; use atm0s_sdn::{services::visualization, NetInput, NetOutput, SdnExtIn, SdnExtOut, SdnWorker, SdnWorkerBusEvent, SdnWorkerCfg, SdnWorkerInput, SdnWorkerOutput, TimePivot}; use media_server_core::cluster::{self, MediaCluster}; use media_server_protocol::transport::{ + whep::{self, WhepConnectRes, WhepDeleteRes, WhepRemoteIceRes}, whip::{self, WhipConnectRes, WhipDeleteRes, WhipRemoteIceRes}, RpcReq, RpcRes, }; @@ -154,17 +155,17 @@ impl MediaServerWorker { match c.try_into().ok()? { TaskType::Sdn => { let now_ms = self.timer.timestamp_ms(now); - if let Some(out) = s.looper_process(self.sdn_worker.pop_output(now_ms)) { + if let Some(out) = s.queue_process(self.sdn_worker.pop_output(now_ms)) { return Some(self.output_sdn(now, out)); } } TaskType::MediaCluster => { - if let Some(out) = s.looper_process(self.media_cluster.pop_output(now)) { + if let Some(out) = s.queue_process(self.media_cluster.pop_output(now)) { return Some(self.output_cluster(now, out)); } } TaskType::MediaWebrtc => { - if let Some(out) = s.looper_process(self.media_webrtc.pop_output(now)) { + if let Some(out) = s.queue_process(self.media_webrtc.pop_output(now)) { return Some(self.output_webrtc(now, out)); } } @@ -268,7 +269,7 @@ impl MediaServerWorker { transport_webrtc::GroupOutput::Ext(_owner, ext) => match ext { transport_webrtc::ExtOut::RemoteIce(req_id, variant, res) => match variant { transport_webrtc::Variant::Whip => Output::ExtRpc(req_id, RpcRes::Whip(whip::RpcRes::RemoteIce(res.map(|_| WhipRemoteIceRes {})))), - transport_webrtc::Variant::Whep => todo!(), + transport_webrtc::Variant::Whep => Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::RemoteIce(res.map(|_| WhepRemoteIceRes {})))), transport_webrtc::Variant::Sdk => { todo!() } @@ -280,6 +281,7 @@ impl MediaServerWorker { impl MediaServerWorker { fn process_rpc<'a>(&mut self, now: Instant, req_id: u64, req: RpcReq) -> Option> { + log::info!("[MediaServerWorker] incoming rpc req {req_id}"); match req { RpcReq::Whip(req) => match req { whip::RpcReq::Connect(req) => match self.media_webrtc.spawn(transport_webrtc::Variant::Whip, &req.sdp) { @@ -287,6 +289,7 @@ impl MediaServerWorker { Err(e) => Some(Output::ExtRpc(req_id, RpcRes::Whip(whip::RpcRes::Connect(Err(e))))), }, whip::RpcReq::RemoteIce(req) => { + log::info!("on rpc request {req_id}, whip::RpcReq::RemoteIce"); let out = self.media_webrtc.on_event( now, GroupInput::Ext(req.conn_id.into(), transport_webrtc::ExtIn::RemoteIce(req_id, transport_webrtc::Variant::Whip, req.ice)), @@ -300,6 +303,26 @@ impl MediaServerWorker { Some(self.output_webrtc(now, out)) } }, + RpcReq::Whep(req) => match req { + whep::RpcReq::Connect(req) => match self.media_webrtc.spawn(transport_webrtc::Variant::Whep, &req.sdp) { + Ok((sdp, conn_id)) => Some(Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::Connect(Ok(WhepConnectRes { conn_id, sdp }))))), + Err(e) => Some(Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::Connect(Err(e))))), + }, + whep::RpcReq::RemoteIce(req) => { + log::info!("on rpc request {req_id}, whep::RpcReq::RemoteIce"); + let out = self.media_webrtc.on_event( + now, + GroupInput::Ext(req.conn_id.into(), transport_webrtc::ExtIn::RemoteIce(req_id, transport_webrtc::Variant::Whep, req.ice)), + )?; + Some(self.output_webrtc(now, out)) + } + whep::RpcReq::Delete(req) => { + //TODO check error instead of auto response ok + self.queue.push_back(Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::Delete(Ok(WhepDeleteRes {}))))); + let out = self.media_webrtc.on_event(now, GroupInput::Close(req.conn_id.into()))?; + Some(self.output_webrtc(now, out)) + } + }, } } } diff --git a/packages/media_utils/Cargo.toml b/packages/media_utils/Cargo.toml index 5f3a0323..93dfcc79 100644 --- a/packages/media_utils/Cargo.toml +++ b/packages/media_utils/Cargo.toml @@ -6,4 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -smallmap = "1.4.2" +smallmap = { workspace = true } diff --git a/packages/protocol/src/endpoint.rs b/packages/protocol/src/endpoint.rs index fc891689..6818fd56 100644 --- a/packages/protocol/src/endpoint.rs +++ b/packages/protocol/src/endpoint.rs @@ -1,8 +1,11 @@ use std::{fmt::Display, str::FromStr}; -use crate::media::{MediaCodec, MediaKind, MediaScaling}; +use crate::{ + media::{MediaCodec, MediaKind, MediaScaling}, + transport::ConnLayer, +}; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub struct ClusterConnId { pub node: u32, pub node_session: u64, @@ -10,27 +13,90 @@ pub struct ClusterConnId { } impl FromStr for ClusterConnId { - type Err = String; + type Err = &'static str; fn from_str(s: &str) -> Result { - todo!() + let parts = s.split('-').collect::>(); + let node = parts.get(0).ok_or("MISSING NODE_ID")?.parse::().map_err(|_| "PARSE ERROR NODE_ID")?; + let node_session = parts.get(1).ok_or("MISSING NODE_SESSION")?.parse::().map_err(|_| "PARSE ERROR NODE_SESSION")?; + let server_conn = parts.get(2).ok_or("MISSING SERVER_CONN")?.parse::().map_err(|_| "PARSE ERROR SERVER_CONN")?; + Ok(Self { node, node_session, server_conn }) } } impl Display for ClusterConnId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({},{},{})", self.node, self.node_session, self.server_conn) + write!(f, "{}-{}-{}", self.node, self.node_session, self.server_conn) } } -#[derive(Clone, Copy)] +impl ConnLayer for ClusterConnId { + type Up = (); + type UpParam = (); + type Down = ServerConnId; + type DownRes = (u32, u64); + + fn down(self) -> (Self::Down, Self::DownRes) { + (self.server_conn, (self.node, self.node_session)) + } + + fn up(self, _param: Self::UpParam) -> Self::Up { + panic!("should not happen") + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub struct ServerConnId { pub worker: u16, pub index: usize, } +impl FromStr for ServerConnId { + type Err = &'static str; + fn from_str(s: &str) -> Result { + let parts = s.split(',').collect::>(); + let worker = parts.get(0).ok_or("MISSING WORKER")?.parse::().map_err(|_| "PARSE ERROR WORKER")?; + let index = parts.get(1).ok_or("MISSING INDEX")?.parse::().map_err(|_| "PARSE ERROR INDEX")?; + Ok(Self { worker, index }) + } +} + impl Display for ServerConnId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({},{})", self.worker, self.index) + write!(f, "{},{}", self.worker, self.index) + } +} + +impl ConnLayer for ServerConnId { + type Up = ClusterConnId; + type UpParam = (u32, u64); + type Down = usize; + type DownRes = u16; + + fn down(self) -> (Self::Down, Self::DownRes) { + (self.index, self.worker) + } + + fn up(self, param: Self::UpParam) -> Self::Up { + ClusterConnId { + node: param.0, + node_session: param.1, + server_conn: self, + } + } +} + +impl ConnLayer for usize { + type Up = ServerConnId; + type UpParam = u16; + type Down = (); + type DownRes = (); + + fn down(self) -> (Self::Down, Self::DownRes) { + panic!("should not happen") + } + + fn up(self, param: Self::UpParam) -> Self::Up { + ServerConnId { index: self, worker: param } } } @@ -49,3 +115,35 @@ pub struct TrackMeta { pub codec: MediaCodec, pub scaling: MediaScaling, } + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use super::{ClusterConnId, ServerConnId}; + + #[test] + fn server_conn_id_parse() { + let conn = ServerConnId { worker: 1, index: 2 }; + assert_eq!(conn.to_string(), "1,2"); + assert_eq!(ServerConnId::from_str("1,2"), Ok(ServerConnId { worker: 1, index: 2 })); + } + + #[test] + fn cluster_conn_id_pase() { + let conn = ClusterConnId { + node: 1, + node_session: 2, + server_conn: ServerConnId { worker: 3, index: 4 }, + }; + assert_eq!(conn.to_string(), "1-2-3,4"); + assert_eq!( + ClusterConnId::from_str("1-2-3,4"), + Ok(ClusterConnId { + node: 1, + node_session: 2, + server_conn: ServerConnId { worker: 3, index: 4 }, + }) + ); + } +} diff --git a/packages/protocol/src/transport.rs b/packages/protocol/src/transport.rs index 6a22650c..74497484 100644 --- a/packages/protocol/src/transport.rs +++ b/packages/protocol/src/transport.rs @@ -1,46 +1,52 @@ use std::fmt::Display; -use crate::endpoint::{ClusterConnId, ServerConnId}; - pub mod webrtc; pub mod whep; pub mod whip; +pub trait ConnLayer { + type Up; + type UpParam; + type Down; + type DownRes; + + fn down(self) -> (Self::Down, Self::DownRes); + fn up(self, param: Self::UpParam) -> Self::Up; +} + #[derive(Debug, Clone, convert_enum::From, convert_enum::TryInto)] pub enum RpcReq { - // Webrtc(webrtc::RpcReq), - // Whep(whep::RpcReq), + Whep(whep::RpcReq), Whip(whip::RpcReq), } -impl RpcReq { - pub fn extract(self) -> (RpcReq, Option) { - todo!() - } -} - -impl RpcReq { - pub fn extract(self) -> (RpcReq, Option) { - todo!() +impl RpcReq { + pub fn down(self) -> (RpcReq, Option) { + match self { + Self::Whip(req) => { + let (req, layer) = req.down(); + (RpcReq::Whip(req), layer) + } + Self::Whep(req) => { + let (req, layer) = req.down(); + (RpcReq::Whep(req), layer) + } + } } } #[derive(Debug, Clone, convert_enum::From, convert_enum::TryInto)] pub enum RpcRes { - // Webrtc(webrtc::RpcRes), - // Whep(whep::RpcRes), + Whep(whep::RpcRes), Whip(whip::RpcRes), } -impl RpcRes { - pub fn up_layer(self, node: u32) -> RpcRes { - todo!() - } -} - -impl RpcRes { - pub fn up_layer(self, worker: u16) -> RpcRes { - todo!() +impl RpcRes { + pub fn up(self, param: Conn::UpParam) -> RpcRes { + match self { + Self::Whip(req) => RpcRes::Whip(req.up(param)), + Self::Whep(req) => RpcRes::Whep(req.up(param)), + } } } diff --git a/packages/protocol/src/transport/whep.rs b/packages/protocol/src/transport/whep.rs index e7c950a2..19d2c9e1 100644 --- a/packages/protocol/src/transport/whep.rs +++ b/packages/protocol/src/transport/whep.rs @@ -1,5 +1,78 @@ +use std::net::IpAddr; + +use super::{ConnLayer, RpcResult}; + +#[derive(Debug, Clone)] +pub struct WhepConnectReq { + pub sdp: String, + pub token: String, + pub ip: IpAddr, + pub user_agent: String, +} + +#[derive(Debug, Clone)] +pub struct WhepConnectRes { + pub conn_id: Conn, + pub sdp: String, +} + +#[derive(Debug, Clone)] +pub struct WhepRemoteIceReq { + pub conn_id: Conn, + pub ice: String, +} + +#[derive(Debug, Clone)] +pub struct WhepRemoteIceRes {} + #[derive(Debug, Clone)] -pub enum RpcReq {} +pub struct WhepDeleteReq { + pub conn_id: Conn, +} #[derive(Debug, Clone)] -pub enum RpcRes {} +pub struct WhepDeleteRes {} + +#[derive(Debug, Clone, convert_enum::From, convert_enum::TryInto)] +pub enum RpcReq { + Connect(WhepConnectReq), + RemoteIce(WhepRemoteIceReq), + Delete(WhepDeleteReq), +} + +impl RpcReq { + pub fn down(self) -> (RpcReq, Option) { + match self { + RpcReq::Connect(req) => (RpcReq::Connect(req), None), + RpcReq::RemoteIce(req) => { + let (down, layer) = req.conn_id.down(); + (RpcReq::RemoteIce(WhepRemoteIceReq { conn_id: down, ice: req.ice }), Some(layer)) + } + RpcReq::Delete(req) => { + let (down, layer) = req.conn_id.down(); + (RpcReq::Delete(WhepDeleteReq { conn_id: down }), Some(layer)) + } + } + } +} + +#[derive(Debug, Clone, convert_enum::From, convert_enum::TryInto)] +pub enum RpcRes { + Connect(RpcResult>), + RemoteIce(RpcResult), + Delete(RpcResult), +} + +impl RpcRes { + pub fn up(self, param: Conn::UpParam) -> RpcRes { + match self { + RpcRes::Connect(Ok(res)) => RpcRes::Connect(Ok(WhepConnectRes { + conn_id: res.conn_id.up(param), + sdp: res.sdp, + })), + RpcRes::Connect(Err(e)) => RpcRes::Connect(Err(e)), + RpcRes::RemoteIce(res) => RpcRes::RemoteIce(res), + RpcRes::Delete(res) => RpcRes::Delete(res), + } + } +} diff --git a/packages/protocol/src/transport/whip.rs b/packages/protocol/src/transport/whip.rs index 83e68122..9e49dd48 100644 --- a/packages/protocol/src/transport/whip.rs +++ b/packages/protocol/src/transport/whip.rs @@ -1,6 +1,6 @@ use std::net::IpAddr; -use super::RpcResult; +use super::{ConnLayer, RpcResult}; #[derive(Debug, Clone)] pub struct WhipConnectReq { @@ -40,9 +40,39 @@ pub enum RpcReq { Delete(WhipDeleteReq), } +impl RpcReq { + pub fn down(self) -> (RpcReq, Option) { + match self { + RpcReq::Connect(req) => (RpcReq::Connect(req), None), + RpcReq::RemoteIce(req) => { + let (down, layer) = req.conn_id.down(); + (RpcReq::RemoteIce(WhipRemoteIceReq { conn_id: down, ice: req.ice }), Some(layer)) + } + RpcReq::Delete(req) => { + let (down, layer) = req.conn_id.down(); + (RpcReq::Delete(WhipDeleteReq { conn_id: down }), Some(layer)) + } + } + } +} + #[derive(Debug, Clone, convert_enum::From, convert_enum::TryInto)] pub enum RpcRes { Connect(RpcResult>), RemoteIce(RpcResult), Delete(RpcResult), } + +impl RpcRes { + pub fn up(self, param: Conn::UpParam) -> RpcRes { + match self { + RpcRes::Connect(Ok(res)) => RpcRes::Connect(Ok(WhipConnectRes { + conn_id: res.conn_id.up(param), + sdp: res.sdp, + })), + RpcRes::Connect(Err(e)) => RpcRes::Connect(Err(e)), + RpcRes::RemoteIce(res) => RpcRes::RemoteIce(res), + RpcRes::Delete(res) => RpcRes::Delete(res), + } + } +} diff --git a/packages/transport_webrtc/src/transport.rs b/packages/transport_webrtc/src/transport.rs index f41d5bb5..37c4b34d 100644 --- a/packages/transport_webrtc/src/transport.rs +++ b/packages/transport_webrtc/src/transport.rs @@ -1,11 +1,15 @@ use std::{net::SocketAddr, ops::Deref, time::Instant}; -use media_server_core::transport::{Transport, TransportInput, TransportOutput}; +use media_server_core::{ + endpoint::{EndpointEvent, EndpointReqId, EndpointRes}, + transport::{Transport, TransportInput, TransportOutput}, +}; use media_server_protocol::{ media::MediaPacket, transport::{RpcError, RpcResult}, }; -use sans_io_runtime::Buffer; +use media_server_utils::Small2dMap; +use sans_io_runtime::backend::{BackendIncoming, BackendOutgoing}; use str0m::{ bwe::Bitrate, change::{DtlsCert, SdpOffer}, @@ -36,23 +40,27 @@ pub enum ExtOut { } enum InternalOutput<'a> { - Str0mReceive(Instant, Protocol, SocketAddr, SocketAddr, Buffer<'a>), - Str0mTick(Instant), Str0mKeyframe(Mid, KeyframeRequestKind), Str0mLimitBitrate(Mid, u64), Str0mSendMedia(Mid, MediaPacket), TransportOutput(TransportOutput<'a, ExtOut>), + Destroy, } trait TransportWebrtcInternal { fn on_tick<'a>(&mut self, now: Instant) -> Option>; - fn on_transport_input<'a>(&mut self, now: Instant, input: TransportInput<'a, ExtIn>) -> Option>; - fn on_str0m_out<'a>(&mut self, now: Instant, out: str0m::Output) -> Option>; + fn on_transport_rpc_res<'a>(&mut self, now: Instant, req_id: EndpointReqId, res: EndpointRes) -> Option>; + fn on_endpoint_event<'a>(&mut self, now: Instant, input: EndpointEvent) -> Option>; + fn on_str0m_event<'a>(&mut self, now: Instant, event: str0m::Event) -> Option>; + fn close<'a>(&mut self, now: Instant) -> Option>; + fn pop_output<'a>(&mut self, now: Instant) -> Option>; } pub struct TransportWebrtc { + next_tick: Option, rtc: Rtc, internal: Box, + ports: Small2dMap, } impl TransportWebrtc { @@ -64,39 +72,31 @@ impl TransportWebrtc { let mut rtc = rtc_config.build(); rtc.direct_api().enable_twcc_feedback(); - for (local_addr, _slot) in &local_addrs { - rtc.add_local_candidate(Candidate::host(*local_addr, Protocol::Udp).expect("Should add local candidate")); + let mut ports = Small2dMap::default(); + for (local_addr, slot) in local_addrs { + ports.insert(local_addr, slot); + rtc.add_local_candidate(Candidate::host(local_addr, Protocol::Udp).expect("Should add local candidate")); } let answer = rtc.sdp_api().accept_offer(offer).map_err(|_e| RpcError::new2(WebrtcError::Str0mError))?; + Ok(( Self { + next_tick: None, rtc, internal: match variant { - Variant::Whip => Box::new(whip::TransportWebrtcWhip::new(local_addrs)), - Variant::Whep => Box::new(whep::TransportWebrtcWhep::new(local_addrs)), + Variant::Whip => Box::new(whip::TransportWebrtcWhip::new()), + Variant::Whep => Box::new(whep::TransportWebrtcWhep::new()), Variant::Sdk => unimplemented!(), }, + ports, }, ice_ufrag, answer.to_sdp_string(), )) } - pub fn on_remote_ice<'a>(&mut self, now: Instant, ice: String) -> Option> { - //TODO - self.pop_event(now) - } - fn process_internal_output<'a>(&mut self, now: Instant, out: InternalOutput<'a>) -> Option> { match out { - InternalOutput::Str0mReceive(now, protocol, source, destination, buf) => { - self.rtc.handle_input(str0m::Input::Receive(now, Receive::new(protocol, source, destination, buf.deref()).ok()?)).ok()?; - self.pop_event(now) - } - InternalOutput::Str0mTick(now) => { - self.rtc.handle_input(str0m::Input::Timeout(now)).ok()?; - self.pop_event(now) - } InternalOutput::Str0mKeyframe(mid, kind) => { self.rtc.direct_api().stream_rx_by_mid(mid, None)?.request_keyframe(kind); self.pop_event(now) @@ -114,28 +114,91 @@ impl TransportWebrtc { self.pop_event(now) } InternalOutput::TransportOutput(out) => Some(out), + InternalOutput::Destroy => Some(TransportOutput::Destroy), } } } impl Transport for TransportWebrtc { fn on_tick<'a>(&mut self, now: Instant) -> Option> { + if let Some(next_tick) = self.next_tick { + if next_tick <= now { + self.next_tick = None; + self.rtc.handle_input(str0m::Input::Timeout(now)).ok()?; + return self.pop_event(now); + } + } + let out = self.internal.on_tick(now)?; self.process_internal_output(now, out) } fn on_input<'a>(&mut self, now: Instant, input: TransportInput<'a, ExtIn>) -> Option> { - let out = self.internal.on_transport_input(now, input)?; - self.process_internal_output(now, out) + match input { + TransportInput::Net(net) => match net { + BackendIncoming::UdpPacket { slot, from, data } => { + let destination = *self.ports.get2(&slot)?; + log::trace!("[TransportWebrtc] recv udp from {} to {}, len {}", from, destination, data.len()); + self.rtc + .handle_input(str0m::Input::Receive(now, Receive::new(Protocol::Udp, from, destination, data.deref()).ok()?)) + .ok()?; + self.pop_event(now) + } + _ => panic!("Unexpected input"), + }, + TransportInput::Endpoint(event) => { + let out = self.internal.on_endpoint_event(now, event)?; + self.process_internal_output(now, out) + } + TransportInput::RpcRes(req_id, res) => { + let out = self.internal.on_transport_rpc_res(now, req_id, res)?; + self.process_internal_output(now, out) + } + TransportInput::Ext(ext) => match ext { + ExtIn::RemoteIce(req_id, variant, _ice) => { + //TODO + Some(TransportOutput::Ext(ExtOut::RemoteIce(req_id, variant, Ok(())))) + } + }, + TransportInput::Close => { + self.internal.close(now); + self.rtc.disconnect(); + self.pop_event(now) + } + } } fn pop_event<'a>(&mut self, now: Instant) -> Option> { + while let Some(out) = self.internal.pop_output(now) { + let out = self.process_internal_output(now, out); + if out.is_some() { + return out; + } + } + loop { let out = self.rtc.poll_output().ok()?; - if let Some(out) = self.internal.on_str0m_out(now, out) { - let out = self.process_internal_output(now, out); - if out.is_some() { - return out; + match out { + str0m::Output::Timeout(tick) => { + self.next_tick = Some(tick); + return None; + } + str0m::Output::Transmit(out) => { + log::trace!("[TransportWebrtc] send udp from {} to {}, len {}", out.source, out.destination, out.contents.len()); + let from = self.ports.get1(&out.source)?; + return Some(TransportOutput::Net(BackendOutgoing::UdpPacket { + slot: *from, + to: out.destination, + data: out.contents.to_vec().into(), + })); + } + str0m::Output::Event(e) => { + if let Some(out) = self.internal.on_str0m_event(now, e) { + let out = self.process_internal_output(now, out); + if out.is_some() { + return out; + } + } } } } diff --git a/packages/transport_webrtc/src/transport/whep.rs b/packages/transport_webrtc/src/transport/whep.rs index 89582682..78639f7a 100644 --- a/packages/transport_webrtc/src/transport/whep.rs +++ b/packages/transport_webrtc/src/transport/whep.rs @@ -1,27 +1,25 @@ use std::{ - net::SocketAddr, + collections::VecDeque, time::{Duration, Instant}, }; use media_server_core::{ endpoint::{EndpointEvent, EndpointLocalTrackEvent}, - transport::{LocalTrackEvent, LocalTrackId, TransportError, TransportEvent, TransportInput, TransportOutput, TransportState}, + transport::{LocalTrackEvent, LocalTrackId, TransportError, TransportEvent, TransportOutput, TransportState}, }; use media_server_protocol::endpoint::{PeerId, TrackMeta, TrackName}; -use media_server_utils::Small2dMap; -use sans_io_runtime::backend::{BackendIncoming, BackendOutgoing}; use str0m::{ media::{Direction, MediaAdded, MediaKind, Mid}, - net::Protocol, - Event as Str0mEvent, IceConnectionState, Output as Str0mOutput, + Event as Str0mEvent, IceConnectionState, }; -use super::{ExtIn, InternalOutput, TransportWebrtcInternal}; +use super::{InternalOutput, TransportWebrtcInternal}; const TIMEOUT_SEC: u64 = 10; const AUDIO_TRACK: LocalTrackId = LocalTrackId(0); const VIDEO_TRACK: LocalTrackId = LocalTrackId(1); +#[derive(Debug)] enum State { New, Connecting { at: Instant }, @@ -31,6 +29,7 @@ enum State { Disconnected(Option), } +#[derive(Debug)] enum TransportWebrtcError { Timeout, } @@ -43,27 +42,21 @@ struct SubscribeStreams { } pub struct TransportWebrtcWhep { - next_tick: Option, state: State, - ports: Small2dMap, audio_mid: Option, video_mid: Option, subscribed: SubscribeStreams, + queue: VecDeque>, } impl TransportWebrtcWhep { - pub fn new(local_addrs: Vec<(SocketAddr, usize)>) -> Self { - let mut ports = Small2dMap::default(); - for (local_addr, slot) in local_addrs { - ports.insert(local_addr, slot); - } + pub fn new() -> Self { Self { state: State::New, - next_tick: None, - ports, audio_mid: None, video_mid: None, subscribed: Default::default(), + queue: VecDeque::new(), } } } @@ -77,7 +70,7 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { } State::Connecting { at } => { if now - *at >= Duration::from_secs(TIMEOUT_SEC) { - log::info!("Connect timed out after {:?}", now - *at); + log::info!("[TransportWebrtcWhep] connect timed out after {:?} => switched to ConnectError", now - *at); self.state = State::ConnectError(TransportWebrtcError::Timeout); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::ConnectError( TransportError::Timeout, @@ -86,7 +79,7 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { } State::Reconnecting { at } => { if now - *at >= Duration::from_secs(TIMEOUT_SEC) { - log::info!("Reconnecting timed out after {:?}", now - *at); + log::info!("[TransportWebrtcWhep] reconnect timed out after {:?} => switched to Disconnected", now - *at); self.state = State::Disconnected(Some(TransportWebrtcError::Timeout)); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(Some( TransportError::Timeout, @@ -95,83 +88,88 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { } _ => {} } - let next_tick = self.next_tick?; - if next_tick > now { - return None; - } - self.next_tick = None; - Some(InternalOutput::Str0mTick(now)) + None } - fn on_transport_input<'a>(&mut self, now: Instant, input: TransportInput<'a, ExtIn>) -> Option> { - match input { - TransportInput::Net(net) => match net { - BackendIncoming::UdpPacket { slot, from, data } => { - let destination = self.ports.get2(&slot)?; - Some(InternalOutput::Str0mReceive(now, Protocol::Udp, from, *destination, data.freeze())) - } - _ => panic!("Unexpected input"), + fn on_endpoint_event<'a>(&mut self, now: Instant, event: EndpointEvent) -> Option> { + match event { + EndpointEvent::PeerJoined(_) => None, + EndpointEvent::PeerLeaved(_) => None, + EndpointEvent::PeerTrackStarted(peer, track, meta) => self.try_subscribe(peer, track, meta), + EndpointEvent::PeerTrackStopped(peer, track) => self.try_unsubscribe(peer, track), + EndpointEvent::LocalMediaTrack(track, event) => match event { + EndpointLocalTrackEvent::Media(pkt) => Some(InternalOutput::Str0mSendMedia(self.video_mid?, pkt)), }, - TransportInput::Endpoint(event) => self.on_endpoint_event(now, event), - TransportInput::Ext(_) => panic!("Unexpected ext input inside whep"), - TransportInput::Close => panic!("Unexpected close input inside whep"), - TransportInput::RpcRes(_, _) => todo!(), + EndpointEvent::RemoteMediaTrack(track, event) => None, } } - fn on_str0m_out<'a>(&mut self, now: Instant, out: Str0mOutput) -> Option> { - match out { - Str0mOutput::Timeout(instance) => { - self.next_tick = Some(instance); - None - } - Str0mOutput::Transmit(out) => { - let from = self.ports.get1(&out.source)?; - return Some(InternalOutput::TransportOutput(TransportOutput::Net(BackendOutgoing::UdpPacket { - slot: *from, - to: out.destination, - data: out.contents.to_vec().into(), - }))); + fn on_transport_rpc_res<'a>(&mut self, now: Instant, req_id: media_server_core::endpoint::EndpointReqId, res: media_server_core::endpoint::EndpointRes) -> Option> { + None + } + + fn on_str0m_event<'a>(&mut self, now: Instant, event: str0m::Event) -> Option> { + match event { + Str0mEvent::Connected => { + log::info!("[TransportWebrtcWhep] connected"); + self.state = State::Connected; + return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))); } - Str0mOutput::Event(event) => match event { - Str0mEvent::Connected => { - self.state = State::Connected; - return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))); - } - Str0mEvent::IceConnectionStateChange(state) => self.on_str0m_state(now, state), - Str0mEvent::MediaAdded(media) => self.on_str0m_media_added(now, media), - Str0mEvent::KeyframeRequest(req) => { - if self.video_mid == Some(req.mid) { - Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::LocalTrack( - VIDEO_TRACK, - LocalTrackEvent::RequestKeyFrame, - )))) - } else { - None - } + Str0mEvent::IceConnectionStateChange(state) => self.on_str0m_state(now, state), + Str0mEvent::MediaAdded(media) => self.on_str0m_media_added(now, media), + Str0mEvent::KeyframeRequest(req) => { + if self.video_mid == Some(req.mid) { + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::LocalTrack( + VIDEO_TRACK, + LocalTrackEvent::RequestKeyFrame, + )))) + } else { + None } - _ => None, - }, + } + _ => None, } } + + fn close<'a>(&mut self, now: Instant) -> Option> { + self.queue.push_back(InternalOutput::Destroy); + log::info!("[TransportWebrtcWhep] switched to disconnected with close action"); + self.state = State::Disconnected(None); + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None))))) + } + + fn pop_output<'a>(&mut self, now: Instant) -> Option> { + self.queue.pop_front() + } } impl TransportWebrtcWhep { fn on_str0m_state<'a>(&mut self, now: Instant, state: IceConnectionState) -> Option> { + log::info!("[TransportWebrtcWhep] str0m state changed {:?}", state); + match state { IceConnectionState::New => None, IceConnectionState::Checking => None, - IceConnectionState::Connected | IceConnectionState::Completed => { - if matches!(self.state, State::Reconnecting { at: _ }) { + IceConnectionState::Connected | IceConnectionState::Completed => match &self.state { + State::Connecting { at } => { + log::info!("[TransportWebrtcWhep] switched to connected after {:?}", now - *at); self.state = State::Connected; Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))) - } else { + } + State::Reconnecting { at } => { + log::info!("[TransportWebrtcWhep] switched to reconnected after {:?}", now - *at); + self.state = State::Connected; + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))) + } + _ => { + log::warn!("[TransportWebrtcWhep] wrong internal state {:?}", self.state); None } - } + }, IceConnectionState::Disconnected => { if matches!(self.state, State::Connected) { self.state = State::Reconnecting { at: now }; + log::info!("[TransportWebrtcWhep] switched to reconnecting"); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Reconnecting)))); } else { return None; @@ -204,19 +202,6 @@ impl TransportWebrtcWhep { )))) } } - - fn on_endpoint_event<'a>(&mut self, now: Instant, event: EndpointEvent) -> Option> { - match event { - EndpointEvent::PeerJoined(_) => None, - EndpointEvent::PeerLeaved(_) => None, - EndpointEvent::PeerTrackStarted(peer, track, meta) => self.try_subscribe(peer, track, meta), - EndpointEvent::PeerTrackStopped(peer, track) => self.try_unsubscribe(peer, track), - EndpointEvent::LocalMediaTrack(track, event) => match event { - EndpointLocalTrackEvent::Media(pkt) => Some(InternalOutput::Str0mSendMedia(self.video_mid?, pkt)), - }, - EndpointEvent::RemoteMediaTrack(track, event) => None, - } - } } impl TransportWebrtcWhep { diff --git a/packages/transport_webrtc/src/transport/whip.rs b/packages/transport_webrtc/src/transport/whip.rs index 82656d05..0c8bdc10 100644 --- a/packages/transport_webrtc/src/transport/whip.rs +++ b/packages/transport_webrtc/src/transport/whip.rs @@ -1,23 +1,20 @@ use std::{ - net::SocketAddr, + collections::VecDeque, time::{Duration, Instant}, }; use media_server_core::{ endpoint::EndpointEvent, - transport::{RemoteTrackEvent, RemoteTrackId, TransportError, TransportEvent, TransportInput, TransportOutput, TransportState}, + transport::{RemoteTrackEvent, RemoteTrackId, TransportError, TransportEvent, TransportOutput, TransportState}, }; -use media_server_utils::Small2dMap; -use sans_io_runtime::backend::{BackendIncoming, BackendOutgoing}; use str0m::{ media::{Direction, KeyframeRequestKind, MediaAdded, MediaKind, Mid}, - net::Protocol, - Event as Str0mEvent, IceConnectionState, Output as Str0mOutput, + Event as Str0mEvent, IceConnectionState, }; use crate::utils::rtp_to_media_packet; -use super::{ExtIn, InternalOutput, TransportWebrtcInternal}; +use super::{InternalOutput, TransportWebrtcInternal}; const TIMEOUT_SEC: u64 = 10; const AUDIO_TRACK: RemoteTrackId = RemoteTrackId(0); @@ -39,25 +36,19 @@ enum TransportWebrtcError { } pub struct TransportWebrtcWhip { - next_tick: Option, state: State, - ports: Small2dMap, audio_mid: Option, video_mid: Option, + queue: VecDeque>, } impl TransportWebrtcWhip { - pub fn new(local_addrs: Vec<(SocketAddr, usize)>) -> Self { - let mut ports = Small2dMap::default(); - for (local_addr, slot) in local_addrs { - ports.insert(local_addr, slot); - } + pub fn new() -> Self { Self { state: State::New, - next_tick: None, - ports, audio_mid: None, video_mid: None, + queue: VecDeque::new(), } } } @@ -89,67 +80,67 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { } _ => {} } - let next_tick = self.next_tick?; - if next_tick > now { - return None; - } - self.next_tick = None; - Some(InternalOutput::Str0mTick(now)) + None } - fn on_transport_input<'a>(&mut self, now: Instant, input: TransportInput<'a, ExtIn>) -> Option> { - match input { - TransportInput::Net(net) => match net { - BackendIncoming::UdpPacket { slot, from, data } => { - let destination = self.ports.get2(&slot)?; - Some(InternalOutput::Str0mReceive(now, Protocol::Udp, from, *destination, data.freeze())) + fn on_endpoint_event<'a>(&mut self, now: Instant, event: EndpointEvent) -> Option> { + match event { + EndpointEvent::PeerJoined(_) => todo!(), + EndpointEvent::PeerLeaved(_) => todo!(), + EndpointEvent::PeerTrackStarted(_, _, _) => todo!(), + EndpointEvent::PeerTrackStopped(_, _) => todo!(), + EndpointEvent::RemoteMediaTrack(_, event) => match event { + media_server_core::endpoint::EndpointRemoteTrackEvent::RequestKeyFrame => { + let mid = self.video_mid?; + Some(InternalOutput::Str0mKeyframe(mid, KeyframeRequestKind::Pli)) + } + media_server_core::endpoint::EndpointRemoteTrackEvent::LimitBitrateBps(bitrate) => { + let mid = self.video_mid?; + Some(InternalOutput::Str0mLimitBitrate(mid, bitrate)) } - _ => panic!("Unexpected input"), }, - TransportInput::Endpoint(event) => self.on_endpoint_event(now, event), - TransportInput::RpcRes(req_id, res) => todo!(), - TransportInput::Ext(_) => panic!("Unexpected ext input inside whip"), - TransportInput::Close => panic!("Unexpected close input inside whip"), + EndpointEvent::LocalMediaTrack(_, _) => todo!(), } } - fn on_str0m_out<'a>(&mut self, now: Instant, out: Str0mOutput) -> Option> { - match out { - Str0mOutput::Timeout(instance) => { - self.next_tick = Some(instance); - None + fn on_transport_rpc_res<'a>(&mut self, now: Instant, req_id: media_server_core::endpoint::EndpointReqId, res: media_server_core::endpoint::EndpointRes) -> Option> { + None + } + + fn on_str0m_event<'a>(&mut self, now: Instant, event: Str0mEvent) -> Option> { + match event { + Str0mEvent::Connected => { + self.state = State::Connected; + return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))); } - Str0mOutput::Transmit(out) => { - let from = self.ports.get1(&out.source)?; - return Some(InternalOutput::TransportOutput(TransportOutput::Net(BackendOutgoing::UdpPacket { - slot: *from, - to: out.destination, - data: out.contents.to_vec().into(), - }))); + Str0mEvent::IceConnectionStateChange(state) => self.on_str0m_state(now, state), + Str0mEvent::MediaAdded(media) => self.on_str0m_media_added(now, media), + Str0mEvent::RtpPacket(pkt) => { + let track = if *pkt.header.payload_type == 111 { + AUDIO_TRACK + } else { + VIDEO_TRACK + }; + let pkt = rtp_to_media_packet(pkt); + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::RemoteTrack( + track, + RemoteTrackEvent::Media(pkt), + )))) } - Str0mOutput::Event(event) => match event { - Str0mEvent::Connected => { - self.state = State::Connected; - return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))); - } - Str0mEvent::IceConnectionStateChange(state) => self.on_str0m_state(now, state), - Str0mEvent::MediaAdded(media) => self.on_str0m_media_added(now, media), - Str0mEvent::RtpPacket(pkt) => { - let track = if *pkt.header.payload_type == 111 { - AUDIO_TRACK - } else { - VIDEO_TRACK - }; - let pkt = rtp_to_media_packet(pkt); - Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::RemoteTrack( - track, - RemoteTrackEvent::Media(pkt), - )))) - } - _ => None, - }, + _ => None, } } + + fn close<'a>(&mut self, now: Instant) -> Option> { + self.queue.push_back(InternalOutput::Destroy); + log::info!("[TransportWebrtcWhep] switched to disconnected with close action"); + self.state = State::Disconnected(None); + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None))))) + } + + fn pop_output<'a>(&mut self, now: Instant) -> Option> { + self.queue.pop_front() + } } impl TransportWebrtcWhip { @@ -200,24 +191,4 @@ impl TransportWebrtcWhip { )))) } } - - fn on_endpoint_event<'a>(&mut self, now: Instant, event: EndpointEvent) -> Option> { - match event { - EndpointEvent::PeerJoined(_) => todo!(), - EndpointEvent::PeerLeaved(_) => todo!(), - EndpointEvent::PeerTrackStarted(_, _, _) => todo!(), - EndpointEvent::PeerTrackStopped(_, _) => todo!(), - EndpointEvent::RemoteMediaTrack(_, event) => match event { - media_server_core::endpoint::EndpointRemoteTrackEvent::RequestKeyFrame => { - let mid = self.video_mid?; - Some(InternalOutput::Str0mKeyframe(mid, KeyframeRequestKind::Pli)) - } - media_server_core::endpoint::EndpointRemoteTrackEvent::LimitBitrateBps(bitrate) => { - let mid = self.video_mid?; - Some(InternalOutput::Str0mLimitBitrate(mid, bitrate)) - } - }, - EndpointEvent::LocalMediaTrack(_, _) => todo!(), - } - } } diff --git a/packages/transport_webrtc/src/worker.rs b/packages/transport_webrtc/src/worker.rs index 7e261e5b..6ca3b18e 100644 --- a/packages/transport_webrtc/src/worker.rs +++ b/packages/transport_webrtc/src/worker.rs @@ -91,6 +91,7 @@ impl MediaWorkerWebrtc { match input { GroupInput::Net(BackendIncoming::UdpListenResult { bind: _, result }) => { let (addr, slot) = result.ok()?; + log::info!("[MediaWorkerWebrtc] UdpListenResult {addr}, slot {slot}"); self.addrs.push((addr, slot)); None } @@ -104,6 +105,7 @@ impl MediaWorkerWebrtc { Some(self.process_output(owner.index(), out)) } GroupInput::Ext(owner, ext) => { + log::info!("[MediaWorkerWebrtc] on ext to owner {:?}", owner); let out = self.endpoints.on_event(now, owner.index(), EndpointInput::Ext(ext))?; Some(self.process_output(owner.index(), out)) }