From 7da74c2c5054dcca50b9ee830d0867efc6e67ba8 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Sun, 29 Sep 2024 22:29:33 +0700 Subject: [PATCH] WIP: check app is same or not when re-join --- packages/transport_webrtc/src/lib.rs | 1 + packages/transport_webrtc/src/transport.rs | 3 +- .../transport_webrtc/src/transport/webrtc.rs | 94 +++++++++++++++++-- packages/transport_webrtc/src/worker.rs | 8 +- 4 files changed, 93 insertions(+), 13 deletions(-) diff --git a/packages/transport_webrtc/src/lib.rs b/packages/transport_webrtc/src/lib.rs index 75e4b176..e00a6457 100644 --- a/packages/transport_webrtc/src/lib.rs +++ b/packages/transport_webrtc/src/lib.rs @@ -18,4 +18,5 @@ pub enum WebrtcError { RpcEndpointNotFound = 0x2006, RpcTokenInvalid = 0x2007, RpcTokenRoomPeerNotMatch = 0x2008, + RpcTokenAppNotMatch = 0x2009, } diff --git a/packages/transport_webrtc/src/transport.rs b/packages/transport_webrtc/src/transport.rs index 5980862e..3ef7e189 100644 --- a/packages/transport_webrtc/src/transport.rs +++ b/packages/transport_webrtc/src/transport.rs @@ -119,6 +119,7 @@ pub struct TransportWebrtc { impl TransportWebrtc { pub fn new( + app: AppContext, remote: IpAddr, variant: VariantParams, offer: &str, @@ -164,7 +165,7 @@ impl TransportWebrtc { // we need to start sctp as client side for handling restart-ice in new server // if not, datachannel will not connect successful after reconnect to new server rtc.direct_api().start_sctp(true); - Box::new(webrtc::TransportWebrtcSdk::new(req, extra_data, secure, remote)) + Box::new(webrtc::TransportWebrtcSdk::new(app, req, extra_data, secure, remote)) } }; diff --git a/packages/transport_webrtc/src/transport/webrtc.rs b/packages/transport_webrtc/src/transport/webrtc.rs index c2f46a3e..9ca4ca9a 100644 --- a/packages/transport_webrtc/src/transport/webrtc.rs +++ b/packages/transport_webrtc/src/transport/webrtc.rs @@ -13,6 +13,7 @@ use media_server_core::{ }; use media_server_protocol::{ endpoint::{AudioMixerConfig, PeerId, PeerMeta, RoomId, RoomInfoPublish, RoomInfoSubscribe}, + multi_tenancy::AppContext, protobuf::{ self, features::{ @@ -82,6 +83,7 @@ enum TransportWebrtcError { } pub struct TransportWebrtcSdk { + app: AppContext, remote: IpAddr, extra_data: Option, join: Option<(RoomId, PeerId, Option, RoomInfoPublish, RoomInfoSubscribe)>, @@ -98,12 +100,13 @@ pub struct TransportWebrtcSdk { } impl TransportWebrtcSdk { - pub fn new(req: ConnectRequest, extra_data: Option, secure: Arc, remote: IpAddr) -> Self { + pub fn new(app: AppContext, req: ConnectRequest, extra_data: Option, secure: Arc, remote: IpAddr) -> Self { let tracks = req.tracks.unwrap_or_default(); let local_tracks: Vec = tracks.receivers.into_iter().enumerate().map(|(index, r)| LocalTrack::new((index as u16).into(), r)).collect(); let remote_tracks: Vec = tracks.senders.into_iter().enumerate().map(|(index, s)| RemoteTrack::new((index as u16).into(), s)).collect(); if let Some(j) = req.join { Self { + app, remote, extra_data, join: Some((j.room.into(), j.peer.into(), j.metadata, j.publish.unwrap_or_default().into(), j.subscribe.unwrap_or_default().into())), @@ -130,6 +133,7 @@ impl TransportWebrtcSdk { } } else { Self { + app, remote, extra_data, join: None, @@ -709,8 +713,10 @@ impl TransportWebrtcSdk { metadata: info.metadata, extra_data: self.extra_data.clone(), }; - if let Some((_ctx, token)) = self.secure.decode_token::(&req.token) { - if token.room == Some(info.room.clone()) && token.peer == Some(info.peer.clone()) { + if let Some((ctx, token)) = self.secure.decode_token::(&req.token) { + if ctx.app != self.app.app { + self.send_rpc_res_err(req_id, RpcError::new2(WebrtcError::RpcTokenAppNotMatch)); + } else if token.room == Some(info.room.clone()) && token.peer == Some(info.peer.clone()) { let mixer_cfg = info.features.and_then(|f| { f.mixer.map(|m| AudioMixerConfig { mode: m.mode().into(), @@ -893,21 +899,26 @@ mod tests { }; use media_server_protocol::{ endpoint::{PeerMeta, RoomInfoPublish, RoomInfoSubscribe}, - multi_tenancy::AppContext, + multi_tenancy::{AppContext, AppId}, protobuf::{ - gateway, + self, gateway, session::{self, client_event, ClientEvent}, shared, }, tokens::WebrtcToken, + transport::RpcError, }; use media_server_secure::{ jwt::{MediaEdgeSecureJwt, MediaGatewaySecureJwt}, MediaGatewaySecure, }; + use prost::Message; use str0m::channel::ChannelId; - use crate::transport::{InternalOutput, TransportWebrtcInternal}; + use crate::{ + transport::{InternalOutput, TransportWebrtcInternal}, + WebrtcError, + }; use super::TransportWebrtcSdk; @@ -918,6 +929,7 @@ mod tests { #[test] fn join_room_first() { + let app = AppContext::root_app(); let req = gateway::ConnectRequest { join: Some(session::RoomJoin { room: "room".to_string(), @@ -935,7 +947,7 @@ mod tests { let now = Instant::now(); let ip = IpAddr::V4(Ipv4Addr::LOCALHOST); let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice())); - let mut transport = TransportWebrtcSdk::new(req, Some("extra_data".to_string()), secure_jwt.clone(), ip); + let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip); assert_eq!(transport.pop_output(now), None); transport.on_tick(now); @@ -971,6 +983,7 @@ mod tests { #[test] fn join_room_lazy() { + let app = AppContext::root_app(); let req = gateway::ConnectRequest::default(); let channel_id = create_channel_id(); @@ -979,7 +992,7 @@ mod tests { let ip = IpAddr::V4(Ipv4Addr::LOCALHOST); let gateway_jwt = MediaGatewaySecureJwt::from(b"1234".as_slice()); let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice())); - let mut transport = TransportWebrtcSdk::new(req, Some("extra_data".to_string()), secure_jwt.clone(), ip); + let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip); assert_eq!(transport.pop_output(now), None); transport.on_tick(now); @@ -1045,6 +1058,71 @@ mod tests { assert_eq!(transport.pop_output(now), None); } + #[test] + fn join_room_lazy_wrong_app() { + let app = AppContext::root_app(); + let req = gateway::ConnectRequest::default(); + + let channel_id = create_channel_id(); + + let now = Instant::now(); + let ip = IpAddr::V4(Ipv4Addr::LOCALHOST); + let gateway_jwt = MediaGatewaySecureJwt::from(b"1234".as_slice()); + let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice())); + let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip); + assert_eq!(transport.pop_output(now), None); + + transport.on_tick(now); + assert_eq!( + transport.pop_output(now), + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connecting(ip))))) + ); + + transport.on_str0m_event(now, str0m::Event::ChannelOpen(channel_id, "data".to_string())); + assert_eq!( + transport.pop_output(now), + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected(ip))))) + ); + assert_eq!(transport.pop_output(now), None); + + let token = gateway_jwt.encode_token( + &AppContext { app: AppId::from("app1") }, + WebrtcToken { + room: Some("demo".to_string()), + peer: Some("peer1".to_string()), + record: false, + extra_data: Some("extra_data".to_string()), + }, + 10000, + ); + transport.on_str0m_channel_event(ClientEvent { + seq: 0, + event: Some(client_event::Event::Request(session::Request { + req_id: 1, + request: Some(session::request::Request::Session(session::request::Session { + request: Some(session::request::session::Request::Join(session::request::session::Join { + info: Some(session::RoomJoin { + room: "demo".to_string(), + peer: "peer1".to_string(), + metadata: None, + publish: None, + subscribe: None, + features: None, + }), + token: token.clone(), + })), + })), + })), + }); + + let response = protobuf::session::response::Response::Error(RpcError::new2(WebrtcError::RpcTokenAppNotMatch).into()); + let event = protobuf::session::server_event::Event::Response(protobuf::session::Response { req_id: 1, response: Some(response) }); + let event_buf = protobuf::session::ServerEvent { seq: 0, event: Some(event) }.encode_to_vec(); + + assert_eq!(transport.pop_output(now), Some(InternalOutput::Str0mSendData(channel_id, event_buf))); + assert_eq!(transport.pop_output(now), None); + } + //TODO test remote track non-source //TODO test remote track with source //TODO test remote track attach, detach diff --git a/packages/transport_webrtc/src/worker.rs b/packages/transport_webrtc/src/worker.rs index 30874075..1f7b0a28 100644 --- a/packages/transport_webrtc/src/worker.rs +++ b/packages/transport_webrtc/src/worker.rs @@ -78,25 +78,25 @@ impl MediaWorkerWebrtc { pub fn spawn(&mut self, app: AppContext, remote: IpAddr, session_id: u64, variant: VariantParams, offer: &str) -> RpcResult<(bool, String, usize)> { let cfg = match &variant { VariantParams::Whip(_, _, _, record) => EndpointCfg { - app, + app: app.clone(), max_ingress_bitrate: 2_500_000, max_egress_bitrate: 2_500_000, record: *record, }, VariantParams::Whep(_, _, _) => EndpointCfg { - app, + app: app.clone(), max_ingress_bitrate: 2_500_000, max_egress_bitrate: 2_500_000, record: false, }, VariantParams::Webrtc(_, _, _, record, _) => EndpointCfg { - app, + app: app.clone(), max_ingress_bitrate: 2_500_000, max_egress_bitrate: 2_500_000, record: *record, }, }; - let (tran, ufrag, sdp) = TransportWebrtc::new(remote, variant, offer, self.dtls_cert.clone(), &self.addrs, &self.addrs_alt, self.ice_lite)?; + let (tran, ufrag, sdp) = TransportWebrtc::new(app, remote, variant, offer, self.dtls_cert.clone(), &self.addrs, &self.addrs_alt, self.ice_lite)?; log::info!("[TransportWebrtc] create endpoint with config {:?}", cfg); let endpoint = Endpoint::new(session_id, cfg, tran); let index = self.endpoints.add_task(endpoint);