Skip to content

Commit

Permalink
fix: webrtc transport stuck on connect_error cause memory leak (8xFF#453
Browse files Browse the repository at this point in the history
)
  • Loading branch information
giangndm authored Nov 14, 2024
1 parent 599e2ad commit 887dcde
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 10 deletions.
1 change: 1 addition & 0 deletions packages/media_core/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ where
type Time = Instant;

fn is_empty(&self) -> bool {
// we don't need to check shutdown here, because it can be shutdown by transport itself
self.internal.is_empty() && self.transport.is_empty()
}

Expand Down
99 changes: 93 additions & 6 deletions packages/transport_webrtc/src/transport/webrtc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ enum State {
Disconnected,
}

impl State {
fn is_shutdown(&self) -> bool {
matches!(self, State::ConnectError(_) | State::Disconnected)
}
}

#[derive(Debug)]
enum TransportWebrtcError {
Timeout,
Expand Down Expand Up @@ -199,7 +205,7 @@ impl<ES: MediaEdgeSecure> TransportWebrtcInternal for TransportWebrtcSdk<ES> {
}

fn is_empty(&self) -> bool {
matches!(self.state, State::Disconnected) && self.queue.is_empty()
self.state.is_shutdown() && self.queue.is_empty()
}

fn on_tick(&mut self, now: Instant) {
Expand Down Expand Up @@ -452,7 +458,6 @@ impl<ES: MediaEdgeSecure> TransportWebrtcInternal for TransportWebrtcSdk<ES> {
response: Some(MessageChannelResponse::Sub(Subscribe {})),
}),
),

EndpointMessageChannelRes::Unsubscribe(Ok(_)) => self.send_rpc_res(
req_id.0,
media_server_protocol::protobuf::session::response::Response::MessageChannel(MessageChannel {
Expand Down Expand Up @@ -591,7 +596,7 @@ impl<ES: MediaEdgeSecure> TransportWebrtcInternal for TransportWebrtcSdk<ES> {
}

fn on_shutdown(&mut self, _now: Instant) {
if !matches!(self.state, State::Disconnected) {
if !self.state.is_shutdown() {
log::info!("[TransportWebrtcSdk] switched to disconnected with close action");
self.state = State::Disconnected;
self.queue
Expand Down Expand Up @@ -907,12 +912,12 @@ mod tests {
use std::{
net::{IpAddr, Ipv4Addr},
sync::Arc,
time::Instant,
time::{Duration, Instant},
};

use media_server_core::{
endpoint::EndpointReq,
transport::{TransportEvent, TransportOutput, TransportState},
transport::{TransportError, TransportEvent, TransportOutput, TransportState},
};
use media_server_protocol::{
endpoint::{PeerMeta, RoomInfoPublish, RoomInfoSubscribe},
Expand All @@ -933,7 +938,7 @@ mod tests {
use str0m::channel::ChannelId;

use crate::{
transport::{InternalOutput, TransportWebrtcInternal},
transport::{webrtc::TIMEOUT_SEC, InternalOutput, TransportWebrtcInternal},
WebrtcError,
};

Expand Down Expand Up @@ -1140,6 +1145,88 @@ mod tests {
assert_eq!(transport.pop_output(now), None);
}

#[test]
fn connect_error_shutdown() {
let app = AppContext::root_app();
let req = gateway::ConnectRequest::default();

let now = Instant::now();
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice()));
let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connecting(ip)))))
);

transport.on_tick(now + Duration::from_secs(TIMEOUT_SEC));
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::ConnectError(
TransportError::Timeout
)))))
);
assert_eq!(transport.pop_output(now), None);
assert!(transport.is_empty());
}

#[test]
fn shutdown_before_connected() {
let app = AppContext::root_app();
let req = gateway::ConnectRequest::default();

let now = Instant::now();
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice()));
let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connecting(ip)))))
);

transport.on_shutdown(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None)))))
);
assert_eq!(transport.pop_output(now), None);
assert!(transport.is_empty());
}

#[test]
fn shutdown_after_connected() {
let app = AppContext::root_app();
let req = gateway::ConnectRequest::default();

let channel_id = create_channel_id();

let now = Instant::now();
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let secure_jwt = Arc::new(MediaEdgeSecureJwt::from(b"1234".as_slice()));
let mut transport = TransportWebrtcSdk::new(app, req, Some("extra_data".to_string()), secure_jwt.clone(), ip);
assert_eq!(transport.pop_output(now), None);

transport.on_str0m_event(now, str0m::Event::ChannelOpen(channel_id, "data".to_string()));
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected(ip)))))
);

transport.on_shutdown(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None)))))
);
assert_eq!(transport.pop_output(now), None);
assert!(transport.is_empty());
}

//TODO test remote track non-source
//TODO test remote track with source
//TODO test remote track attach, detach
Expand Down
115 changes: 112 additions & 3 deletions packages/transport_webrtc/src/transport/whep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ enum State {
Disconnected,
}

impl State {
fn is_shutdown(&self) -> bool {
matches!(self, State::ConnectError(_) | State::Disconnected)
}
}

#[derive(Debug)]
enum TransportWebrtcError {
Timeout,
Expand Down Expand Up @@ -86,7 +92,7 @@ impl TransportWebrtcInternal for TransportWebrtcWhep {
fn on_codec_config(&mut self, _cfg: &str0m::format::CodecConfig) {}

fn is_empty(&self) -> bool {
matches!(self.state, State::Disconnected) && self.queue.is_empty()
self.state.is_shutdown() && self.queue.is_empty()
}

fn on_tick(&mut self, now: Instant) {
Expand Down Expand Up @@ -176,6 +182,8 @@ impl TransportWebrtcInternal for TransportWebrtcWhep {
Str0mEvent::Connected => {
log::info!("[TransportWebrtcWhep] connected");
self.state = State::Connected;
self.queue
.push_back(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected(self.remote)))));
self.queue.push_back(InternalOutput::TransportOutput(TransportOutput::RpcReq(
0.into(),
EndpointReq::JoinRoom(
Expand All @@ -190,8 +198,6 @@ impl TransportWebrtcInternal for TransportWebrtcWhep {
None,
),
)));
self.queue
.push_back(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected(self.remote)))));
}
Str0mEvent::IceConnectionStateChange(state) => self.on_str0m_state(now, state),
Str0mEvent::MediaAdded(media) => self.on_str0m_media_added(now, media),
Expand Down Expand Up @@ -377,3 +383,106 @@ impl TransportWebrtcWhep {
}
}
}

#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;

use super::*;

#[test]
fn shutdown_before_connected() {
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let now = Instant::now();
let mut transport = TransportWebrtcWhep::new("room".into(), "peer".into(), None, ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connecting(ip)))))
);
assert_eq!(transport.pop_output(now), None);

transport.on_shutdown(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None)))))
);
assert_eq!(transport.pop_output(now), None);
assert!(transport.is_empty());
}

#[test]
fn shutdown_after_connected() {
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let now = Instant::now();
let room: RoomId = "room".into();
let peer: PeerId = "peer".into();

let mut transport = TransportWebrtcWhep::new(room.clone(), peer.clone(), None, ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connecting(ip)))))
);

transport.on_str0m_event(now, str0m::Event::Connected);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connected(ip)))))
);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::RpcReq(
0.into(),
EndpointReq::JoinRoom(
room.clone(),
peer.clone(),
PeerMeta { metadata: None, extra_data: None },
RoomInfoPublish { peer: false, tracks: false },
RoomInfoSubscribe { peers: false, tracks: true },
None,
),
)))
);

transport.on_shutdown(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Disconnected(None)))))
);
assert_eq!(transport.pop_output(now), None);
assert!(transport.is_empty());
}

#[test]
fn shutdown_after_connect_error() {
let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
let now = Instant::now();
let room: RoomId = "room".into();
let peer: PeerId = "peer".into();

let mut transport = TransportWebrtcWhep::new(room.clone(), peer.clone(), None, ip);
assert_eq!(transport.pop_output(now), None);

transport.on_tick(now);
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::Connecting(ip)))))
);

transport.on_tick(now + Duration::from_secs(TIMEOUT_SEC));
assert_eq!(
transport.pop_output(now),
Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::State(TransportState::ConnectError(
TransportError::Timeout
)))))
);

assert_eq!(transport.pop_output(now), None);
assert!(transport.is_empty());
}
}
Loading

0 comments on commit 887dcde

Please sign in to comment.