From 41206fe32b0b247ef2454551b54cbdd223401073 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Mon, 22 Apr 2024 00:14:47 +0700 Subject: [PATCH] WIP: working on simple cluster --- Cargo.lock | 14 ++++ Cargo.toml | 3 + bin/Cargo.toml | 2 +- packages/media_core/src/cluster.rs | 29 +++++--- packages/media_core/src/cluster/room.rs | 1 + packages/media_core/src/endpoint.rs | 9 ++- packages/media_core/src/endpoint/internal.rs | 74 ++++++++++++------- .../src/endpoint/internal/local_track.rs | 21 +++++- .../src/endpoint/internal/remote_track.rs | 20 ++++- packages/media_core/src/transport.rs | 10 ++- packages/media_runner/Cargo.toml | 1 + packages/media_runner/src/worker.rs | 23 ++++-- packages/media_utils/src/small_2dmap.rs | 2 +- packages/protocol/Cargo.toml | 4 +- packages/protocol/src/endpoint.rs | 7 +- packages/protocol/src/media.rs | 6 +- packages/transport_webrtc/src/lib.rs | 2 +- packages/transport_webrtc/src/transport.rs | 25 +++++-- .../transport_webrtc/src/transport/whep.rs | 25 +++---- .../transport_webrtc/src/transport/whip.rs | 36 ++++++--- packages/transport_webrtc/src/worker.rs | 6 +- 21 files changed, 218 insertions(+), 102 deletions(-) create mode 100644 packages/media_core/src/cluster/room.rs diff --git a/Cargo.lock b/Cargo.lock index 13eb9c5d..5a062f53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -616,6 +616,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "0.99.17" @@ -1189,6 +1200,8 @@ name = "media-server-protocol" version = "0.1.0" dependencies = [ "convert-enum", + "derivative", + "derive_more", ] [[package]] @@ -1200,6 +1213,7 @@ dependencies = [ "log", "media-server-core", "media-server-protocol", + "rand", "sans-io-runtime", "transport-webrtc", ] diff --git a/Cargo.toml b/Cargo.toml index bc4ef8e3..ecf5ab6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,6 @@ convert-enum = "0.1" num_enum = "0.7" log = "0.4" smallmap = "1.4" +derivative = "2.2" +derive_more = "0.99" +rand = "0.8" diff --git a/bin/Cargo.toml b/bin/Cargo.toml index 19aaae0c..f9ccbf62 100644 --- a/bin/Cargo.toml +++ b/bin/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" tracing-subscriber = { version = "0.3", features = ["env-filter", "std"] } clap = { version = "4.5", features = ["env", "derive"] } log = { workspace = true } +rand = { workspace = true } poem = { version = "3.0", features = ["static-files"] } poem-openapi = { version = "5.0", features = ["swagger-ui"] } tokio = { version = "1.37", features = ["full"] } @@ -16,4 +17,3 @@ sans-io-runtime = { workspace = true } atm0s-sdn = { workspace = true } media-server-protocol = { path = "../packages/protocol" } media-server-runner = { path = "../packages/media_runner" } -rand = "0.8.5" diff --git a/packages/media_core/src/cluster.rs b/packages/media_core/src/cluster.rs index 7284e30d..22eeb1b0 100644 --- a/packages/media_core/src/cluster.rs +++ b/packages/media_core/src/cluster.rs @@ -1,14 +1,17 @@ -use std::{marker::PhantomData, time::Instant}; +use std::{fmt::Debug, hash::Hash, time::Instant}; use atm0s_sdn::features::{FeaturesControl, FeaturesEvent}; use media_server_protocol::{ endpoint::{PeerId, RoomId, TrackMeta, TrackName}, media::MediaPacket, }; +use media_server_utils::Small2dMap; use crate::transport::{LocalTrackId, RemoteTrackId}; -#[derive(Clone)] +mod room; + +#[derive(Debug, Clone)] pub enum ClusterRemoteTrackControl { Started(TrackName), Media(MediaPacket), @@ -20,20 +23,21 @@ pub enum ClusterRemoteTrackEvent { RequestKeyFrame, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub enum ClusterLocalTrackControl { Subscribe(PeerId, TrackName), RequestKeyFrame, Unsubscribe, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub enum ClusterLocalTrackEvent { Started, Media(MediaPacket), Ended, } +#[derive(Debug)] pub enum ClusterEndpointControl { JoinRoom(RoomId, PeerId), LeaveRoom, @@ -65,25 +69,28 @@ pub enum Output { Endpoint(Vec, ClusterEndpointEvent), } -#[derive(Debug)] pub struct MediaCluster { - _tmp: PhantomData, + endpoints: Small2dMap, } -impl Default for MediaCluster { +impl Default for MediaCluster { fn default() -> Self { - Self { _tmp: PhantomData } + Self { endpoints: Small2dMap::default() } } } -impl MediaCluster { +impl MediaCluster { pub fn on_tick(&mut self, now: Instant) -> Option> { //TODO None } - pub fn on_input(&mut self, now: Instant, input: Input) -> Option> { - //TODO + pub fn on_sdn_event(&mut self, now: Instant, event: FeaturesEvent) -> Option> { + None + } + + pub fn on_endpoint_control(&mut self, now: Instant, owner: Owner, control: ClusterEndpointControl) -> Option> { + log::info!("[MediaCluster] {:?} control {:?}", owner, control); None } diff --git a/packages/media_core/src/cluster/room.rs b/packages/media_core/src/cluster/room.rs new file mode 100644 index 00000000..ec33b617 --- /dev/null +++ b/packages/media_core/src/cluster/room.rs @@ -0,0 +1 @@ +pub struct ClusterRoom {} diff --git a/packages/media_core/src/endpoint.rs b/packages/media_core/src/endpoint.rs index 44f0b599..fa9140fb 100644 --- a/packages/media_core/src/endpoint.rs +++ b/packages/media_core/src/endpoint.rs @@ -37,6 +37,11 @@ pub enum EndpointLocalTrackRes { } pub struct EndpointReqId(pub u64); +impl From for EndpointReqId { + fn from(value: u64) -> Self { + Self(value) + } +} /// This is control APIs, which is used to control server from Endpoint SDK pub enum EndpointReq { @@ -85,7 +90,7 @@ pub enum EndpointOutput<'a, Ext> { Net(BackendOutgoing<'a>), Cluster(ClusterEndpointControl), Ext(Ext), - Shutdown, + Destroy, } #[derive(num_enum::TryFromPrimitive, num_enum::IntoPrimitive)] @@ -189,7 +194,6 @@ impl, ExtIn, ExtOut> Endpoint { let out = self.internal.on_transport_rpc(now, req_id, req)?; self.process_internal_output(now, out) } - TransportOutput::Destroy => Some(EndpointOutput::Shutdown), } } @@ -205,6 +209,7 @@ impl, ExtIn, ExtOut> Endpoint { self.process_transport_output(now, out) } InternalOutput::Cluster(control) => Some(EndpointOutput::Cluster(control)), + InternalOutput::Destroy => Some(EndpointOutput::Destroy), } } } diff --git a/packages/media_core/src/endpoint/internal.rs b/packages/media_core/src/endpoint/internal.rs index e79cbbde..03c0231b 100644 --- a/packages/media_core/src/endpoint/internal.rs +++ b/packages/media_core/src/endpoint/internal.rs @@ -1,13 +1,10 @@ -use std::{ - collections::{HashMap, VecDeque}, - time::Instant, -}; +use std::{collections::VecDeque, time::Instant}; use media_server_protocol::endpoint::{PeerId, RoomId}; use crate::{ cluster::{ClusterEndpointControl, ClusterEndpointEvent, ClusterLocalTrackEvent, ClusterRemoteTrackEvent}, - transport::{LocalTrackEvent, LocalTrackId, RemoteTrackEvent, RemoteTrackId, TransportEvent, TransportInput, TransportState, TransportStats}, + transport::{LocalTrackEvent, LocalTrackId, RemoteTrackEvent, RemoteTrackId, TransportEvent, TransportState, TransportStats}, }; use self::{local_track::EndpointLocalTrack, remote_track::EndpointRemoteTrack}; @@ -21,6 +18,7 @@ pub enum InternalOutput { Event(EndpointEvent), RpcRes(EndpointReqId, EndpointRes), Cluster(ClusterEndpointControl), + Destroy, } pub struct EndpointInternal { @@ -73,17 +71,21 @@ impl EndpointInternal { EndpointReq::JoinRoom(room, peer) => { self.room = Some((room.clone(), peer.clone())); if matches!(self.state, TransportState::Connecting) { - Some(InternalOutput::Cluster(ClusterEndpointControl::JoinRoom(room, peer))) - } else { + log::info!("[EndpointInternal] join_room({room}, {peer}) but in Connecting state => wait"); None + } else { + log::info!("[EndpointInternal] join_room({room}, {peer})"); + Some(InternalOutput::Cluster(ClusterEndpointControl::JoinRoom(room, peer))) } } EndpointReq::LeaveRoom => { - self.room.take()?; + let (room, peer) = self.room.take()?; if matches!(self.state, TransportState::Connecting) { - Some(InternalOutput::Cluster(ClusterEndpointControl::LeaveRoom)) - } else { + log::info!("[EndpointInternal] leave_room({room}, {peer}) but in Connecting state => only clear local"); None + } else { + log::info!("[EndpointInternal] leave_room({room}, {peer})"); + Some(InternalOutput::Cluster(ClusterEndpointControl::LeaveRoom)) } } EndpointReq::RemoteTrack(track, control) => todo!(), @@ -94,9 +96,16 @@ impl EndpointInternal { fn on_transport_state_changed<'a>(&mut self, now: Instant, state: TransportState) -> Option { self.state = state; match &self.state { - TransportState::Connecting => None, - TransportState::ConnectError(_) => None, + TransportState::Connecting => { + log::info!("[EndpointInternal] connecting"); + None + } + TransportState::ConnectError(err) => { + log::info!("[EndpointInternal] connect error"); + Some(InternalOutput::Destroy) + } TransportState::Connected => { + log::info!("[EndpointInternal] connected"); for i in 0..self.local_tracks_id.len() { let id = self.local_tracks_id[i]; if let Some(out) = self.local_tracks.get_mut(&id).expect("Should have").on_connected(now) { @@ -114,12 +123,25 @@ impl EndpointInternal { } } } - let (room, peer) = self.room.as_ref()?; - self.queue.push_back(InternalOutput::Cluster(ClusterEndpointControl::JoinRoom(room.clone(), peer.clone()))); + if let Some((room, peer)) = self.room.as_ref() { + log::info!("[EndpointInternal] join_room({room}, {peer}) after connected"); + self.queue.push_back(InternalOutput::Cluster(ClusterEndpointControl::JoinRoom(room.clone(), peer.clone()))); + } + self.queue.pop_front() + } + TransportState::Reconnecting => { + log::info!("[EndpointInternal] reconnecting"); + None + } + TransportState::Disconnected(err) => { + log::info!("[EndpointInternal] disconnected {:?}", err); + if let Some((room, peer)) = &self.room { + log::info!("[EndpointInternal] leave_room({room}, {peer}) after disconnected"); + self.queue.push_back(InternalOutput::Cluster(ClusterEndpointControl::LeaveRoom)); + } + self.queue.push_back(InternalOutput::Destroy); self.queue.pop_front() } - TransportState::Reconnecting => None, - TransportState::Disconnected(_) => None, } } @@ -128,7 +150,7 @@ impl EndpointInternal { self.remote_tracks_id.push(track); self.remote_tracks.insert(track, EndpointRemoteTrack::default()); } - let out = self.remote_tracks.get_mut(&track)?.on_event(now, event)?; + let out = self.remote_tracks.get_mut(&track)?.on_transport_event(now, event)?; self.on_remote_track_output(now, track, out) } @@ -137,7 +159,7 @@ impl EndpointInternal { self.local_tracks_id.push(track); self.local_tracks.insert(track, EndpointLocalTrack::default()); } - let out = self.local_tracks.get_mut(&track)?.on_event(now, event)?; + let out = self.local_tracks.get_mut(&track)?.on_transport_event(now, event)?; self.on_local_track_output(now, track, out) } @@ -160,25 +182,25 @@ impl EndpointInternal { } fn on_cluster_remote_track<'a>(&mut self, now: Instant, id: RemoteTrackId, event: ClusterRemoteTrackEvent) -> Option { - match event { - _ => todo!(), - } + None } fn on_cluster_local_track<'a>(&mut self, now: Instant, id: LocalTrackId, event: ClusterLocalTrackEvent) -> Option { - match event { - _ => todo!(), - } + None } } /// This block for internal local and remote track impl EndpointInternal { fn on_remote_track_output<'a>(&mut self, now: Instant, id: RemoteTrackId, out: remote_track::Output) -> Option { - todo!() + match out { + remote_track::Output::Cluster(control) => Some(InternalOutput::Cluster(ClusterEndpointControl::RemoteTrack(id, control))), + } } fn on_local_track_output<'a>(&mut self, now: Instant, id: LocalTrackId, out: local_track::Output) -> Option { - todo!() + match out { + local_track::Output::Cluster(control) => Some(InternalOutput::Cluster(ClusterEndpointControl::LocalTrack(id, control))), + } } } diff --git a/packages/media_core/src/endpoint/internal/local_track.rs b/packages/media_core/src/endpoint/internal/local_track.rs index 4f4bed99..b532135e 100644 --- a/packages/media_core/src/endpoint/internal/local_track.rs +++ b/packages/media_core/src/endpoint/internal/local_track.rs @@ -1,8 +1,13 @@ use std::time::Instant; -use crate::transport::{LocalTrackEvent, LocalTrackId}; +use crate::{ + cluster::ClusterLocalTrackControl, + transport::{LocalTrackEvent, LocalTrackId}, +}; -pub enum Output {} +pub enum Output { + Cluster(ClusterLocalTrackControl), +} #[derive(Default)] pub struct EndpointLocalTrack {} @@ -11,8 +16,16 @@ impl EndpointLocalTrack { pub fn on_connected(&mut self, now: Instant) -> Option { None } - pub fn on_event(&mut self, now: Instant, event: LocalTrackEvent) -> Option { - None + pub fn on_transport_event(&mut self, now: Instant, event: LocalTrackEvent) -> Option { + log::info!("[EndpointLocalTrack] on event {:?}", event); + match event { + LocalTrackEvent::Started => None, + //TODO maybe switch is RPC type + LocalTrackEvent::Switch(Some((peer, track))) => Some(Output::Cluster(ClusterLocalTrackControl::Subscribe(peer, track))), + LocalTrackEvent::Switch(None) => Some(Output::Cluster(ClusterLocalTrackControl::Unsubscribe)), + LocalTrackEvent::RequestKeyFrame => Some(Output::Cluster(ClusterLocalTrackControl::RequestKeyFrame)), + LocalTrackEvent::Ended => None, + } } pub fn pop_output(&mut self) -> Option { None diff --git a/packages/media_core/src/endpoint/internal/remote_track.rs b/packages/media_core/src/endpoint/internal/remote_track.rs index 701d2801..bd2905fd 100644 --- a/packages/media_core/src/endpoint/internal/remote_track.rs +++ b/packages/media_core/src/endpoint/internal/remote_track.rs @@ -1,8 +1,12 @@ use std::time::Instant; -use crate::transport::{RemoteTrackEvent, RemoteTrackId}; +use media_server_protocol::endpoint::TrackName; -pub enum Output {} +use crate::{cluster::ClusterRemoteTrackControl, transport::RemoteTrackEvent}; + +pub enum Output { + Cluster(ClusterRemoteTrackControl), +} #[derive(Default)] pub struct EndpointRemoteTrack {} @@ -11,9 +15,17 @@ impl EndpointRemoteTrack { pub fn on_connected(&mut self, now: Instant) -> Option { None } - pub fn on_event(&mut self, now: Instant, event: RemoteTrackEvent) -> Option { - None + + pub fn on_transport_event(&mut self, now: Instant, event: RemoteTrackEvent) -> Option { + match event { + RemoteTrackEvent::Started { name } => Some(Output::Cluster(ClusterRemoteTrackControl::Started(TrackName(name)))), + RemoteTrackEvent::Paused => None, + RemoteTrackEvent::Resumed => None, + RemoteTrackEvent::Media(_) => None, + RemoteTrackEvent::Ended => Some(Output::Cluster(ClusterRemoteTrackControl::Ended)), + } } + pub fn pop_output(&mut self) -> Option { None } diff --git a/packages/media_core/src/transport.rs b/packages/media_core/src/transport.rs index e494d4df..4056ff38 100644 --- a/packages/media_core/src/transport.rs +++ b/packages/media_core/src/transport.rs @@ -9,11 +9,11 @@ use sans_io_runtime::backend::{BackendIncoming, BackendOutgoing}; use crate::endpoint::{EndpointEvent, EndpointReq, EndpointReqId, EndpointRes}; -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy)] pub struct TransportId(pub u64); /// RemoteTrackId is used for track which received media from client -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RemoteTrackId(pub u16); impl Hash for RemoteTrackId { @@ -23,7 +23,7 @@ impl Hash for RemoteTrackId { } /// LocalTrackId is used for track which send media to client -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct LocalTrackId(pub u16); impl Hash for LocalTrackId { @@ -32,6 +32,7 @@ impl Hash for LocalTrackId { } } +#[derive(Debug)] pub enum TransportError { Timeout, } @@ -52,6 +53,7 @@ pub struct TransportStats { } /// This is used for notifying state of local track to endpoint +#[derive(Debug)] pub enum LocalTrackEvent { Started, Switch(Option<(PeerId, TrackName)>), @@ -66,6 +68,7 @@ impl LocalTrackEvent { } /// This is used for notifying state of remote track to endpoint +#[derive(Debug)] pub enum RemoteTrackEvent { Started { name: String }, Paused, @@ -102,7 +105,6 @@ pub enum TransportOutput<'a, Ext> { Event(TransportEvent), RpcReq(EndpointReqId, EndpointReq), Ext(Ext), - Destroy, } pub trait Transport { diff --git a/packages/media_runner/Cargo.toml b/packages/media_runner/Cargo.toml index 49af27d0..9897f390 100644 --- a/packages/media_runner/Cargo.toml +++ b/packages/media_runner/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +rand = { workspace = true } log = { workspace = true } convert-enum = { workspace = true } media-server-protocol = { path = "../protocol" } diff --git a/packages/media_runner/src/worker.rs b/packages/media_runner/src/worker.rs index dac4b866..0e03a0d9 100644 --- a/packages/media_runner/src/worker.rs +++ b/packages/media_runner/src/worker.rs @@ -7,6 +7,7 @@ use media_server_protocol::transport::{ whip::{self, WhipConnectRes, WhipDeleteRes, WhipRemoteIceRes}, RpcReq, RpcRes, }; +use rand::random; use sans_io_runtime::{ backend::{BackendIncoming, BackendOutgoing}, TaskSwitcher, @@ -63,7 +64,7 @@ impl TryFrom for TaskType { } } -#[derive(convert_enum::From)] +#[derive(convert_enum::From, Debug, Clone, Hash, PartialEq, Eq)] enum MediaClusterOwner { Webrtc(WebrtcOwner), } @@ -207,7 +208,7 @@ impl MediaServerWorker { SdnWorkerOutput::Ext(out) => Output::ExtSdn(out), SdnWorkerOutput::ExtWorker(out) => match out { SdnExtOut::FeaturesEvent(e) => { - if let Some(out) = self.media_cluster.on_input(now, cluster::Input::Sdn(e)) { + if let Some(out) = self.media_cluster.on_sdn_event(now, e) { self.output_cluster(now, out) } else { Output::Continue @@ -259,7 +260,7 @@ impl MediaServerWorker { match out { transport_webrtc::GroupOutput::Net(out) => Output::Net(Owner::MediaWebrtc, out), transport_webrtc::GroupOutput::Cluster(owner, control) => { - if let Some(out) = self.media_cluster.on_input(now, cluster::Input::Endpoint(owner.into(), control)) { + if let Some(out) = self.media_cluster.on_endpoint_control(now, owner.into(), control) { self.output_cluster(now, out) } else { Output::Continue @@ -284,7 +285,10 @@ impl MediaServerWorker { log::info!("[MediaServerWorker] incoming rpc req {req_id}"); match req { RpcReq::Whip(req) => match req { - whip::RpcReq::Connect(req) => match self.media_webrtc.spawn(transport_webrtc::Variant::Whip, &req.sdp) { + whip::RpcReq::Connect(req) => match self + .media_webrtc + .spawn(transport_webrtc::VariantParams::Whip(req.token.into(), "publisher".to_string().into()), &req.sdp) + { Ok((sdp, conn_id)) => Some(Output::ExtRpc(req_id, RpcRes::Whip(whip::RpcRes::Connect(Ok(WhipConnectRes { conn_id, sdp }))))), Err(e) => Some(Output::ExtRpc(req_id, RpcRes::Whip(whip::RpcRes::Connect(Err(e))))), }, @@ -304,10 +308,13 @@ impl MediaServerWorker { } }, RpcReq::Whep(req) => match req { - whep::RpcReq::Connect(req) => match self.media_webrtc.spawn(transport_webrtc::Variant::Whep, &req.sdp) { - Ok((sdp, conn_id)) => Some(Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::Connect(Ok(WhepConnectRes { conn_id, sdp }))))), - Err(e) => Some(Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::Connect(Err(e))))), - }, + whep::RpcReq::Connect(req) => { + let peer_id = format!("whep-{}", random::()); + match self.media_webrtc.spawn(transport_webrtc::VariantParams::Whep(req.token.into(), peer_id.into()), &req.sdp) { + Ok((sdp, conn_id)) => Some(Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::Connect(Ok(WhepConnectRes { conn_id, sdp }))))), + Err(e) => Some(Output::ExtRpc(req_id, RpcRes::Whep(whep::RpcRes::Connect(Err(e))))), + } + } whep::RpcReq::RemoteIce(req) => { log::info!("on rpc request {req_id}, whep::RpcReq::RemoteIce"); let out = self.media_webrtc.on_event( diff --git a/packages/media_utils/src/small_2dmap.rs b/packages/media_utils/src/small_2dmap.rs index 3ce70793..84dde954 100644 --- a/packages/media_utils/src/small_2dmap.rs +++ b/packages/media_utils/src/small_2dmap.rs @@ -1,6 +1,6 @@ use std::hash::Hash; -pub struct Small2dMap { +pub struct Small2dMap { data: smallmap::Map, reverse: smallmap::Map, } diff --git a/packages/protocol/Cargo.toml b/packages/protocol/Cargo.toml index d0bd0f8b..0e79eea1 100644 --- a/packages/protocol/Cargo.toml +++ b/packages/protocol/Cargo.toml @@ -6,4 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -convert-enum = "0.1.0" +convert-enum = { workspace = true } +derivative = { workspace = true } +derive_more = { workspace = true } diff --git a/packages/protocol/src/endpoint.rs b/packages/protocol/src/endpoint.rs index 6818fd56..88c5b8bc 100644 --- a/packages/protocol/src/endpoint.rs +++ b/packages/protocol/src/endpoint.rs @@ -1,3 +1,4 @@ +use derive_more::{AsRef, From}; use std::{fmt::Display, str::FromStr}; use crate::{ @@ -100,13 +101,13 @@ impl ConnLayer for usize { } } -#[derive(Clone, PartialEq, Eq)] +#[derive(From, AsRef, Debug, derive_more::Display, Clone, PartialEq, Eq, Hash)] pub struct RoomId(pub String); -#[derive(Clone, PartialEq, Eq)] +#[derive(From, AsRef, Debug, derive_more::Display, Clone, PartialEq, Eq, Hash)] pub struct PeerId(pub String); -#[derive(Clone, PartialEq, Eq)] +#[derive(From, AsRef, Debug, derive_more::Display, Clone, PartialEq, Eq, Hash)] pub struct TrackName(pub String); #[derive(Clone)] diff --git a/packages/protocol/src/media.rs b/packages/protocol/src/media.rs index 0ec90bbb..30b04a5f 100644 --- a/packages/protocol/src/media.rs +++ b/packages/protocol/src/media.rs @@ -1,3 +1,5 @@ +use derivative::Derivative; + #[derive(Clone)] pub enum MediaKind { Audio, @@ -29,12 +31,14 @@ pub enum MediaCodec { Vp9, } -#[derive(Clone)] +#[derive(Derivative, Clone)] +#[derivative(Debug)] pub struct MediaPacket { pub pt: u8, pub ts: u32, pub seq: u64, pub marker: bool, pub nackable: bool, + #[derivative(Debug = "ignore")] pub data: Vec, } diff --git a/packages/transport_webrtc/src/lib.rs b/packages/transport_webrtc/src/lib.rs index 2ec06e82..67fe8cab 100644 --- a/packages/transport_webrtc/src/lib.rs +++ b/packages/transport_webrtc/src/lib.rs @@ -3,7 +3,7 @@ mod transport; mod utils; mod worker; -pub use transport::{ExtIn, ExtOut, Variant}; +pub use transport::{ExtIn, ExtOut, Variant, VariantParams}; pub use worker::{GroupInput, GroupOutput, MediaWorkerWebrtc, WebrtcOwner}; #[derive(num_enum::TryFromPrimitive, num_enum::IntoPrimitive)] diff --git a/packages/transport_webrtc/src/transport.rs b/packages/transport_webrtc/src/transport.rs index 37c4b34d..1ade7731 100644 --- a/packages/transport_webrtc/src/transport.rs +++ b/packages/transport_webrtc/src/transport.rs @@ -5,6 +5,7 @@ use media_server_core::{ transport::{Transport, TransportInput, TransportOutput}, }; use media_server_protocol::{ + endpoint::{PeerId, RoomId}, media::MediaPacket, transport::{RpcError, RpcResult}, }; @@ -25,6 +26,12 @@ use crate::WebrtcError; mod whep; mod whip; +pub enum VariantParams { + Whip(RoomId, PeerId), + Whep(RoomId, PeerId), + Sdk, +} + pub enum Variant { Whip, Whep, @@ -44,7 +51,6 @@ enum InternalOutput<'a> { Str0mLimitBitrate(Mid, u64), Str0mSendMedia(Mid, MediaPacket), TransportOutput(TransportOutput<'a, ExtOut>), - Destroy, } trait TransportWebrtcInternal { @@ -64,7 +70,7 @@ pub struct TransportWebrtc { } impl TransportWebrtc { - pub fn new(variant: Variant, offer: &str, dtls_cert: DtlsCert, local_addrs: Vec<(SocketAddr, usize)>) -> RpcResult<(Self, String, String)> { + pub fn new(variant: VariantParams, offer: &str, dtls_cert: DtlsCert, local_addrs: Vec<(SocketAddr, usize)>) -> RpcResult<(Self, String, String)> { let offer = SdpOffer::from_sdp_string(offer).map_err(|_e| RpcError::new2(WebrtcError::SdpError))?; let rtc_config = Rtc::builder().set_rtp_mode(true).set_ice_lite(true).set_dtls_cert(dtls_cert).set_local_ice_credentials(IceCreds::new()); let ice_ufrag = rtc_config.local_ice_credentials().as_ref().expect("should have ice credentials").ufrag.clone(); @@ -84,9 +90,9 @@ impl TransportWebrtc { next_tick: None, rtc, internal: match variant { - Variant::Whip => Box::new(whip::TransportWebrtcWhip::new()), - Variant::Whep => Box::new(whep::TransportWebrtcWhep::new()), - Variant::Sdk => unimplemented!(), + VariantParams::Whip(room, peer) => Box::new(whip::TransportWebrtcWhip::new(room, peer)), + VariantParams::Whep(room, peer) => Box::new(whep::TransportWebrtcWhep::new(room, peer)), + VariantParams::Sdk => unimplemented!(), }, ports, }, @@ -114,7 +120,6 @@ impl TransportWebrtc { self.pop_event(now) } InternalOutput::TransportOutput(out) => Some(out), - InternalOutput::Destroy => Some(TransportOutput::Destroy), } } } @@ -161,9 +166,13 @@ impl Transport for TransportWebrtc { } }, TransportInput::Close => { - self.internal.close(now); + log::info!("[TransportWebrtc] close request"); self.rtc.disconnect(); - self.pop_event(now) + if let Some(out) = self.internal.close(now) { + self.process_internal_output(now, out) + } else { + self.pop_event(now) + } } } } diff --git a/packages/transport_webrtc/src/transport/whep.rs b/packages/transport_webrtc/src/transport/whep.rs index 78639f7a..a4447565 100644 --- a/packages/transport_webrtc/src/transport/whep.rs +++ b/packages/transport_webrtc/src/transport/whep.rs @@ -4,10 +4,10 @@ use std::{ }; use media_server_core::{ - endpoint::{EndpointEvent, EndpointLocalTrackEvent}, + endpoint::{EndpointEvent, EndpointLocalTrackEvent, EndpointReq}, transport::{LocalTrackEvent, LocalTrackId, TransportError, TransportEvent, TransportOutput, TransportState}, }; -use media_server_protocol::endpoint::{PeerId, TrackMeta, TrackName}; +use media_server_protocol::endpoint::{PeerId, RoomId, TrackMeta, TrackName}; use str0m::{ media::{Direction, MediaAdded, MediaKind, Mid}, Event as Str0mEvent, IceConnectionState, @@ -42,6 +42,8 @@ struct SubscribeStreams { } pub struct TransportWebrtcWhep { + room: RoomId, + peer: PeerId, state: State, audio_mid: Option, video_mid: Option, @@ -50,8 +52,10 @@ pub struct TransportWebrtcWhep { } impl TransportWebrtcWhep { - pub fn new() -> Self { + pub fn new(room: RoomId, peer: PeerId) -> Self { Self { + room, + peer, state: State::New, audio_mid: None, video_mid: None, @@ -113,6 +117,10 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { Str0mEvent::Connected => { log::info!("[TransportWebrtcWhep] connected"); self.state = State::Connected; + self.queue.push_back(InternalOutput::TransportOutput(TransportOutput::RpcReq( + 0.into(), + EndpointReq::JoinRoom(self.room.clone(), self.peer.clone()), + ))); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))); } Str0mEvent::IceConnectionStateChange(state) => self.on_str0m_state(now, state), @@ -132,7 +140,6 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { } fn close<'a>(&mut self, now: Instant) -> Option> { - self.queue.push_back(InternalOutput::Destroy); log::info!("[TransportWebrtcWhep] switched to disconnected with close action"); self.state = State::Disconnected(None); Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None))))) @@ -151,20 +158,12 @@ impl TransportWebrtcWhep { IceConnectionState::New => None, IceConnectionState::Checking => None, IceConnectionState::Connected | IceConnectionState::Completed => match &self.state { - State::Connecting { at } => { - log::info!("[TransportWebrtcWhep] switched to connected after {:?}", now - *at); - self.state = State::Connected; - Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))) - } State::Reconnecting { at } => { log::info!("[TransportWebrtcWhep] switched to reconnected after {:?}", now - *at); self.state = State::Connected; Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))) } - _ => { - log::warn!("[TransportWebrtcWhep] wrong internal state {:?}", self.state); - None - } + _ => None, }, IceConnectionState::Disconnected => { if matches!(self.state, State::Connected) { diff --git a/packages/transport_webrtc/src/transport/whip.rs b/packages/transport_webrtc/src/transport/whip.rs index 0c8bdc10..fa4dc565 100644 --- a/packages/transport_webrtc/src/transport/whip.rs +++ b/packages/transport_webrtc/src/transport/whip.rs @@ -4,9 +4,10 @@ use std::{ }; use media_server_core::{ - endpoint::EndpointEvent, + endpoint::{EndpointEvent, EndpointReq}, transport::{RemoteTrackEvent, RemoteTrackId, TransportError, TransportEvent, TransportOutput, TransportState}, }; +use media_server_protocol::endpoint::{PeerId, RoomId}; use str0m::{ media::{Direction, KeyframeRequestKind, MediaAdded, MediaKind, Mid}, Event as Str0mEvent, IceConnectionState, @@ -22,6 +23,7 @@ const AUDIO_NAME: &str = "audio_main"; const VIDEO_TRACK: RemoteTrackId = RemoteTrackId(1); const VIDEO_NAME: &str = "video_main"; +#[derive(Debug)] enum State { New, Connecting { at: Instant }, @@ -31,11 +33,14 @@ enum State { Disconnected(Option), } +#[derive(Debug)] enum TransportWebrtcError { Timeout, } pub struct TransportWebrtcWhip { + room: RoomId, + peer: PeerId, state: State, audio_mid: Option, video_mid: Option, @@ -43,8 +48,10 @@ pub struct TransportWebrtcWhip { } impl TransportWebrtcWhip { - pub fn new() -> Self { + pub fn new(room: RoomId, peer: PeerId) -> Self { Self { + room, + peer, state: State::New, audio_mid: None, video_mid: None, @@ -62,7 +69,7 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { } State::Connecting { at } => { if now - *at >= Duration::from_secs(TIMEOUT_SEC) { - log::info!("Connect timed out after {:?}", now - *at); + log::info!("[TransportWebrtcWhip] connect timed out after {:?} => switched to ConnectError", now - *at); self.state = State::ConnectError(TransportWebrtcError::Timeout); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::ConnectError( TransportError::Timeout, @@ -71,7 +78,7 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { } State::Reconnecting { at } => { if now - *at >= Duration::from_secs(TIMEOUT_SEC) { - log::info!("Reconnecting timed out after {:?}", now - *at); + log::info!("[TransportWebrtcWhip] reconnect timed out after {:?} => switched to Disconnected", now - *at); self.state = State::Disconnected(Some(TransportWebrtcError::Timeout)); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(Some( TransportError::Timeout, @@ -111,6 +118,11 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { match event { Str0mEvent::Connected => { self.state = State::Connected; + log::info!("[TransportWebrtcWhip] connected"); + self.queue.push_back(InternalOutput::TransportOutput(TransportOutput::RpcReq( + 0.into(), + EndpointReq::JoinRoom(self.room.clone(), self.peer.clone()), + ))); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))); } Str0mEvent::IceConnectionStateChange(state) => self.on_str0m_state(now, state), @@ -132,8 +144,7 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { } fn close<'a>(&mut self, now: Instant) -> Option> { - self.queue.push_back(InternalOutput::Destroy); - log::info!("[TransportWebrtcWhep] switched to disconnected with close action"); + log::info!("[TransportWebrtcWhip] switched to disconnected with close action"); self.state = State::Disconnected(None); Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None))))) } @@ -145,20 +156,23 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { impl TransportWebrtcWhip { fn on_str0m_state<'a>(&mut self, now: Instant, state: IceConnectionState) -> Option> { + log::info!("[TransportWebrtcWhip] str0m state changed {:?}", state); + match state { IceConnectionState::New => None, IceConnectionState::Checking => None, - IceConnectionState::Connected | IceConnectionState::Completed => { - if matches!(self.state, State::Reconnecting { at: _ }) { + IceConnectionState::Connected | IceConnectionState::Completed => match &self.state { + State::Reconnecting { at } => { + log::info!("[TransportWebrtcWhip] switched to reconnected after {:?}", now - *at); self.state = State::Connected; Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected)))) - } else { - None } - } + _ => None, + }, IceConnectionState::Disconnected => { if matches!(self.state, State::Connected) { self.state = State::Reconnecting { at: now }; + log::info!("[TransportWebrtcWhip] switched to reconnecting"); return Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Reconnecting)))); } else { return None; diff --git a/packages/transport_webrtc/src/worker.rs b/packages/transport_webrtc/src/worker.rs index 6ca3b18e..57a6d033 100644 --- a/packages/transport_webrtc/src/worker.rs +++ b/packages/transport_webrtc/src/worker.rs @@ -13,7 +13,7 @@ use str0m::change::DtlsCert; use crate::{ shared_port::SharedUdpPort, - transport::{ExtIn, ExtOut, TransportWebrtc, Variant}, + transport::{ExtIn, ExtOut, TransportWebrtc, VariantParams}, }; group_task!(Endpoints, Endpoint, EndpointInput<'a, ExtIn>, EndpointOutput<'a, ExtOut>); @@ -53,7 +53,7 @@ impl MediaWorkerWebrtc { } } - pub fn spawn(&mut self, variant: Variant, offer: &str) -> RpcResult<(String, usize)> { + pub fn spawn(&mut self, variant: VariantParams, offer: &str) -> RpcResult<(String, usize)> { let (tran, ufrag, sdp) = TransportWebrtc::new(variant, offer, self.dtls_cert.clone(), self.addrs.clone())?; let endpoint = Endpoint::new(tran); let index = self.endpoints.add_task(endpoint); @@ -65,7 +65,7 @@ impl MediaWorkerWebrtc { match out { EndpointOutput::Net(net) => GroupOutput::Net(net), EndpointOutput::Cluster(control) => GroupOutput::Cluster(WebrtcOwner(index), control), - EndpointOutput::Shutdown => { + EndpointOutput::Destroy => { self.endpoints.remove_task(index); self.shared_port.remove_task(index); GroupOutput::Shutdown(WebrtcOwner(index))