Skip to content

Commit

Permalink
WIP: check app is same or not when re-join
Browse files Browse the repository at this point in the history
  • Loading branch information
giangndm committed Sep 29, 2024
1 parent 3270a80 commit 7da74c2
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 13 deletions.
1 change: 1 addition & 0 deletions packages/transport_webrtc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ pub enum WebrtcError {
RpcEndpointNotFound = 0x2006,
RpcTokenInvalid = 0x2007,
RpcTokenRoomPeerNotMatch = 0x2008,
RpcTokenAppNotMatch = 0x2009,
}
3 changes: 2 additions & 1 deletion packages/transport_webrtc/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ pub struct TransportWebrtc<ES> {

impl<ES: 'static + MediaEdgeSecure> TransportWebrtc<ES> {
pub fn new(
app: AppContext,

Check warning on line 122 in packages/transport_webrtc/src/transport.rs

View check run for this annotation

Codecov / codecov/patch

packages/transport_webrtc/src/transport.rs#L122

Added line #L122 was not covered by tests
remote: IpAddr,
variant: VariantParams<ES>,
offer: &str,
Expand Down Expand Up @@ -164,7 +165,7 @@ impl<ES: 'static + MediaEdgeSecure> TransportWebrtc<ES> {
// we need to start sctp as client side for handling restart-ice in new server
// if not, datachannel will not connect successful after reconnect to new server
rtc.direct_api().start_sctp(true);
Box::new(webrtc::TransportWebrtcSdk::new(req, extra_data, secure, remote))
Box::new(webrtc::TransportWebrtcSdk::new(app, req, extra_data, secure, remote))

Check warning on line 168 in packages/transport_webrtc/src/transport.rs

View check run for this annotation

Codecov / codecov/patch

packages/transport_webrtc/src/transport.rs#L168

Added line #L168 was not covered by tests
}
};

Expand Down
94 changes: 86 additions & 8 deletions packages/transport_webrtc/src/transport/webrtc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use media_server_core::{
};
use media_server_protocol::{
endpoint::{AudioMixerConfig, PeerId, PeerMeta, RoomId, RoomInfoPublish, RoomInfoSubscribe},
multi_tenancy::AppContext,
protobuf::{
self,
features::{
Expand Down Expand Up @@ -82,6 +83,7 @@ enum TransportWebrtcError {
}

pub struct TransportWebrtcSdk<ES> {
app: AppContext,
remote: IpAddr,
extra_data: Option<String>,
join: Option<(RoomId, PeerId, Option<String>, RoomInfoPublish, RoomInfoSubscribe)>,
Expand All @@ -98,12 +100,13 @@ pub struct TransportWebrtcSdk<ES> {
}

impl<ES> TransportWebrtcSdk<ES> {
pub fn new(req: ConnectRequest, extra_data: Option<String>, secure: Arc<ES>, remote: IpAddr) -> Self {
pub fn new(app: AppContext, req: ConnectRequest, extra_data: Option<String>, secure: Arc<ES>, remote: IpAddr) -> Self {
let tracks = req.tracks.unwrap_or_default();
let local_tracks: Vec<LocalTrack> = tracks.receivers.into_iter().enumerate().map(|(index, r)| LocalTrack::new((index as u16).into(), r)).collect();
let remote_tracks: Vec<RemoteTrack> = tracks.senders.into_iter().enumerate().map(|(index, s)| RemoteTrack::new((index as u16).into(), s)).collect();
if let Some(j) = req.join {
Self {
app,
remote,
extra_data,
join: Some((j.room.into(), j.peer.into(), j.metadata, j.publish.unwrap_or_default().into(), j.subscribe.unwrap_or_default().into())),
Expand All @@ -130,6 +133,7 @@ impl<ES> TransportWebrtcSdk<ES> {
}
} else {
Self {
app,
remote,
extra_data,
join: None,
Expand Down Expand Up @@ -709,8 +713,10 @@ impl<ES: MediaEdgeSecure> TransportWebrtcSdk<ES> {
metadata: info.metadata,
extra_data: self.extra_data.clone(),
};
if let Some((_ctx, token)) = self.secure.decode_token::<WebrtcToken>(&req.token) {
if token.room == Some(info.room.clone()) && token.peer == Some(info.peer.clone()) {
if let Some((ctx, token)) = self.secure.decode_token::<WebrtcToken>(&req.token) {
if ctx.app != self.app.app {
self.send_rpc_res_err(req_id, RpcError::new2(WebrtcError::RpcTokenAppNotMatch));
} else if token.room == Some(info.room.clone()) && token.peer == Some(info.peer.clone()) {
let mixer_cfg = info.features.and_then(|f| {
f.mixer.map(|m| AudioMixerConfig {
mode: m.mode().into(),
Expand Down Expand Up @@ -893,21 +899,26 @@ mod tests {
};
use media_server_protocol::{
endpoint::{PeerMeta, RoomInfoPublish, RoomInfoSubscribe},
multi_tenancy::AppContext,
multi_tenancy::{AppContext, AppId},
protobuf::{
gateway,
self, gateway,
session::{self, client_event, ClientEvent},
shared,
},
tokens::WebrtcToken,
transport::RpcError,
};
use media_server_secure::{
jwt::{MediaEdgeSecureJwt, MediaGatewaySecureJwt},
MediaGatewaySecure,
};
use prost::Message;
use str0m::channel::ChannelId;

use crate::transport::{InternalOutput, TransportWebrtcInternal};
use crate::{
transport::{InternalOutput, TransportWebrtcInternal},
WebrtcError,
};

use super::TransportWebrtcSdk;

Expand All @@ -918,6 +929,7 @@ mod tests {

#[test]
fn join_room_first() {
let app = AppContext::root_app();
let req = gateway::ConnectRequest {
join: Some(session::RoomJoin {
room: "room".to_string(),
Expand All @@ -935,7 +947,7 @@ mod tests {
let now = Instant::now();
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice()));
let mut transport = TransportWebrtcSdk::new(req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
Expand Down Expand Up @@ -971,6 +983,7 @@ mod tests {

#[test]
fn join_room_lazy() {
let app = AppContext::root_app();
let req = gateway::ConnectRequest::default();

let channel_id = create_channel_id();
Expand All @@ -979,7 +992,7 @@ mod tests {
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let gateway_jwt = MediaGatewaySecureJwt::from(b"1234".as_slice());
let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice()));
let mut transport = TransportWebrtcSdk::new(req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
Expand Down Expand Up @@ -1045,6 +1058,71 @@ mod tests {
assert_eq!(transport.pop_output(now), None);
}

#[test]
fn join_room_lazy_wrong_app() {
let app = AppContext::root_app();
let req = gateway::ConnectRequest::default();

let channel_id = create_channel_id();

let now = Instant::now();
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let gateway_jwt = MediaGatewaySecureJwt::from(b"1234".as_slice());
let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice()));
let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connecting(ip)))))
);

transport.on_str0m_event(now, str0m::Event::ChannelOpen(channel_id, "data".to_string()));
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected(ip)))))
);
assert_eq!(transport.pop_output(now), None);

let token = gateway_jwt.encode_token(
&AppContext { app: AppId::from("app1") },
WebrtcToken {
room: Some("demo".to_string()),
peer: Some("peer1".to_string()),
record: false,
extra_data: Some("extra_data".to_string()),
},
10000,
);
transport.on_str0m_channel_event(ClientEvent {
seq: 0,
event: Some(client_event::Event::Request(session::Request {
req_id: 1,
request: Some(session::request::Request::Session(session::request::Session {
request: Some(session::request::session::Request::Join(session::request::session::Join {
info: Some(session::RoomJoin {
room: "demo".to_string(),
peer: "peer1".to_string(),
metadata: None,
publish: None,
subscribe: None,
features: None,
}),
token: token.clone(),
})),
})),
})),
});

let response = protobuf::session::response::Response::Error(RpcError::new2(WebrtcError::RpcTokenAppNotMatch).into());
let event = protobuf::session::server_event::Event::Response(protobuf::session::Response { req_id: 1, response: Some(response) });
let event_buf = protobuf::session::ServerEvent { seq: 0, event: Some(event) }.encode_to_vec();

assert_eq!(transport.pop_output(now), Some(InternalOutput::Str0mSendData(channel_id, event_buf)));
assert_eq!(transport.pop_output(now), None);
}

//TODO test remote track non-source
//TODO test remote track with source
//TODO test remote track attach, detach
Expand Down
8 changes: 4 additions & 4 deletions packages/transport_webrtc/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,25 @@ impl<ES: MediaEdgeSecure> MediaWorkerWebrtc<ES> {
pub fn spawn(&mut self, app: AppContext, remote: IpAddr, session_id: u64, variant: VariantParams<ES>, offer: &str) -> RpcResult<(bool, String, usize)> {

Check warning on line 78 in packages/transport_webrtc/src/worker.rs

View check run for this annotation

Codecov / codecov/patch

packages/transport_webrtc/src/worker.rs#L78

Added line #L78 was not covered by tests
let cfg = match &variant {
VariantParams::Whip(_, _, _, record) => EndpointCfg {
app,
app: app.clone(),

Check warning on line 81 in packages/transport_webrtc/src/worker.rs

View check run for this annotation

Codecov / codecov/patch

packages/transport_webrtc/src/worker.rs#L81

Added line #L81 was not covered by tests
max_ingress_bitrate: 2_500_000,
max_egress_bitrate: 2_500_000,
record: *record,
},
VariantParams::Whep(_, _, _) => EndpointCfg {
app,
app: app.clone(),

Check warning on line 87 in packages/transport_webrtc/src/worker.rs

View check run for this annotation

Codecov / codecov/patch

packages/transport_webrtc/src/worker.rs#L87

Added line #L87 was not covered by tests
max_ingress_bitrate: 2_500_000,
max_egress_bitrate: 2_500_000,
record: false,
},
VariantParams::Webrtc(_, _, _, record, _) => EndpointCfg {
app,
app: app.clone(),

Check warning on line 93 in packages/transport_webrtc/src/worker.rs

View check run for this annotation

Codecov / codecov/patch

packages/transport_webrtc/src/worker.rs#L93

Added line #L93 was not covered by tests
max_ingress_bitrate: 2_500_000,
max_egress_bitrate: 2_500_000,
record: *record,
},
};
let (tran, ufrag, sdp) = TransportWebrtc::new(remote, variant, offer, self.dtls_cert.clone(), &self.addrs, &self.addrs_alt, self.ice_lite)?;
let (tran, ufrag, sdp) = TransportWebrtc::new(app, remote, variant, offer, self.dtls_cert.clone(), &self.addrs, &self.addrs_alt, self.ice_lite)?;

Check warning on line 99 in packages/transport_webrtc/src/worker.rs

View check run for this annotation

Codecov / codecov/patch

packages/transport_webrtc/src/worker.rs#L99

Added line #L99 was not covered by tests
log::info!("[TransportWebrtc] create endpoint with config {:?}", cfg);
let endpoint = Endpoint::new(session_id, cfg, tran);
let index = self.endpoints.add_task(endpoint);
Expand Down

0 comments on commit 7da74c2

Please sign in to comment.