Skip to content

Commit

Permalink
feat: auto or manual peer info subscribe (8xFF#135)
Browse files Browse the repository at this point in the history
This PR introduces a room scope mode for each endpoint. There are two modes:

- Auto: all peer and track events in the room are delivered to the endpoint automatically.
- Manual: only events from peers the endpoint has explicitly subscribed to are delivered.

This feature is particularly useful for building an online event platform like Gather.town, where each user only cares about the peers near them. When a peer comes close, the client calls room.subscribe(peer_id); when that peer moves far away, it calls room.unsubscribe(peer_id).
  • Loading branch information
giangndm authored Jan 1, 2024
1 parent 2129174 commit 68a6f4f
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 62 deletions.
46 changes: 3 additions & 43 deletions packages/endpoint/src/endpoint_wrap.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;

use async_std::stream::StreamExt;
use cluster::{ClusterEndpoint, EndpointSubscribeScope, MixMinusAudioMode};
Expand Down Expand Up @@ -37,8 +37,6 @@ where
cluster: C,
tick: async_std::stream::Interval,
timer: Arc<dyn Timer>,
sub_scope: EndpointSubscribeScope,
peer_subscribe: HashMap<String, ()>,
}

impl<T, E, C> MediaEndpoint<T, E, C>
Expand All @@ -48,7 +46,7 @@ where
{
pub fn new(
transport: T,
mut cluster: C,
cluster: C,
room: &str,
peer: &str,
sub_scope: EndpointSubscribeScope,
Expand All @@ -57,12 +55,6 @@ where
mix_minus_size: usize,
) -> Self {
log::info!("[EndpointWrap] create");
//TODO handle error of cluster sub room
if matches!(sub_scope, EndpointSubscribeScope::RoomAuto) {
if let Err(_e) = cluster.on_event(cluster::ClusterEndpointOutgoingEvent::SubscribeRoom) {
todo!("handle error")
}
}
let timer = Arc::new(media_utils::SystemTimer());
let middlewares: Vec<Box<dyn MediaEndpointMiddleware>> = vec![
Box::new(middleware::logger::MediaEndpointEventLogger::new()),
Expand All @@ -74,7 +66,7 @@ where
mix_minus_size,
)),
];
let mut internal = MediaEndpointInternal::new(room, peer, bitrate_type, middlewares);
let mut internal = MediaEndpointInternal::new(room, peer, sub_scope, bitrate_type, middlewares);
internal.on_start(timer.now_ms());

Self {
Expand All @@ -84,8 +76,6 @@ where
cluster,
tick: async_std::stream::interval(std::time::Duration::from_millis(100)),
timer,
sub_scope,
peer_subscribe: HashMap::new(),
}
}

Expand All @@ -106,22 +96,6 @@ where
MediaEndpointInternalEvent::ConnectionError(e) => {
return Err(e);
}
MediaEndpointInternalEvent::SubscribePeer(peer) => {
if matches!(self.sub_scope, EndpointSubscribeScope::RoomManual) {
self.peer_subscribe.insert(peer.clone(), ());
if let Err(_e) = self.cluster.on_event(cluster::ClusterEndpointOutgoingEvent::SubscribePeer(peer)) {
todo!("handle error")
}
}
}
MediaEndpointInternalEvent::UnsubscribePeer(peer) => {
if matches!(self.sub_scope, EndpointSubscribeScope::RoomManual) {
self.peer_subscribe.remove(&peer);
if let Err(_e) = self.cluster.on_event(cluster::ClusterEndpointOutgoingEvent::UnsubscribePeer(peer)) {
todo!("handle error")
}
}
}
},
MediaInternalAction::Endpoint(e) => {
if let Err(e) = self.transport.on_event(self.timer.now_ms(), e) {
Expand Down Expand Up @@ -176,20 +150,6 @@ where
{
fn drop(&mut self) {
log::info!("[EndpointWrap] drop");
match self.sub_scope {
EndpointSubscribeScope::RoomAuto => {
if let Err(_e) = self.cluster.on_event(cluster::ClusterEndpointOutgoingEvent::UnsubscribeRoom) {
todo!("handle error")
}
}
EndpointSubscribeScope::RoomManual => {
for peer in self.peer_subscribe.keys() {
if let Err(_e) = self.cluster.on_event(cluster::ClusterEndpointOutgoingEvent::UnsubscribePeer(peer.clone())) {
todo!("handle error")
}
}
}
}
self.internal.before_drop(self.timer.now_ms());
while let Some(out) = self.internal.pop_action() {
match out {
Expand Down
164 changes: 146 additions & 18 deletions packages/endpoint/src/endpoint_wrap/internal.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::collections::{HashMap, VecDeque};

use cluster::{ClusterEndpointIncomingEvent, ClusterEndpointOutgoingEvent};
use cluster::{ClusterEndpointIncomingEvent, ClusterEndpointOutgoingEvent, EndpointSubscribeScope};
use media_utils::hash_str;
use transport::{MediaKind, TrackId, TransportError, TransportIncomingEvent, TransportOutgoingEvent, TransportStateEvent};

use crate::{
middleware::MediaEndpointMiddleware,
rpc::{EndpointRpcIn, EndpointRpcOut, LocalTrackRpcIn, LocalTrackRpcOut, RemoteTrackRpcIn, RemoteTrackRpcOut, TrackInfo},
MediaEndpointMiddlewareOutput,
MediaEndpointMiddlewareOutput, RpcResponse,
};

use self::{
Expand All @@ -31,8 +31,6 @@ pub enum MediaEndpointInternalEvent {
ConnectionClosed,
ConnectionCloseRequest,
ConnectionError(TransportError),
SubscribePeer(String),
UnsubscribePeer(String),
}

#[derive(Debug, PartialEq, Eq)]
Expand All @@ -45,6 +43,7 @@ pub enum MediaInternalAction {
pub struct MediaEndpointInternal {
room_id: String,
peer_id: String,
sub_scope: EndpointSubscribeScope,
cluster_track_map: HashMap<(String, String), MediaKind>,
local_track_map: HashMap<String, TrackId>,
output_actions: VecDeque<MediaInternalAction>,
Expand All @@ -53,14 +52,16 @@ pub struct MediaEndpointInternal {
bitrate_allocator: bitrate_allocator::BitrateAllocator,
bitrate_limiter: bitrate_limiter::BitrateLimiter,
middlewares: Vec<Box<dyn MediaEndpointMiddleware>>,
subscribe_peers: HashMap<String, ()>,
}

impl MediaEndpointInternal {
pub fn new(room_id: &str, peer_id: &str, bitrate_limiter: BitrateLimiterType, middlewares: Vec<Box<dyn MediaEndpointMiddleware>>) -> Self {
pub fn new(room_id: &str, peer_id: &str, sub_scope: EndpointSubscribeScope, bitrate_limiter: BitrateLimiterType, middlewares: Vec<Box<dyn MediaEndpointMiddleware>>) -> Self {
log::info!("[MediaEndpointInternal {}/{}] create", room_id, peer_id);
Self {
room_id: room_id.into(),
peer_id: peer_id.into(),
sub_scope,
cluster_track_map: HashMap::new(),
local_track_map: HashMap::new(),
output_actions: VecDeque::with_capacity(100),
Expand All @@ -69,6 +70,7 @@ impl MediaEndpointInternal {
bitrate_allocator: bitrate_allocator::BitrateAllocator::new(DEFAULT_BITRATE_OUT_BPS),
bitrate_limiter: bitrate_limiter::BitrateLimiter::new(bitrate_limiter, MAX_BITRATE_IN_BPS),
middlewares,
subscribe_peers: HashMap::new(),
}
}

Expand Down Expand Up @@ -100,6 +102,10 @@ impl MediaEndpointInternal {
}

pub fn on_start(&mut self, now_ms: u64) {
if matches!(self.sub_scope, EndpointSubscribeScope::RoomAuto) {
self.push_cluster(ClusterEndpointOutgoingEvent::SubscribeRoom);
}

for middleware in self.middlewares.iter_mut() {
middleware.on_start(now_ms);
}
Expand Down Expand Up @@ -302,11 +308,28 @@ impl MediaEndpointInternal {
EndpointRpcIn::PeerClose => {
self.push_internal(MediaEndpointInternalEvent::ConnectionCloseRequest);
}
EndpointRpcIn::SubscribePeer(peer) => {}
EndpointRpcIn::UnsubscribePeer(peer) => {}
EndpointRpcIn::MixMinusSourceAdd(_) => todo!(),
EndpointRpcIn::MixMinusSourceRemove(_) => todo!(),
EndpointRpcIn::MixMinusToggle(_) => todo!(),
EndpointRpcIn::SubscribePeer(req) => {
if matches!(self.sub_scope, EndpointSubscribeScope::RoomManual) {
if !self.subscribe_peers.contains_key(&req.data.peer) {
self.subscribe_peers.insert(req.data.peer.clone(), ());
self.push_cluster(ClusterEndpointOutgoingEvent::SubscribePeer(req.data.peer));
}
self.push_rpc(EndpointRpcOut::SubscribePeerRes(RpcResponse::success(req.req_id, true)));
} else {
self.push_rpc(EndpointRpcOut::SubscribePeerRes(RpcResponse::error(req.req_id)));
}
}
EndpointRpcIn::UnsubscribePeer(req) => {
if matches!(self.sub_scope, EndpointSubscribeScope::RoomManual) {
if self.subscribe_peers.remove(&req.data.peer).is_some() {
self.push_cluster(ClusterEndpointOutgoingEvent::UnsubscribePeer(req.data.peer));
}
self.push_rpc(EndpointRpcOut::UnsubscribePeerRes(RpcResponse::success(req.req_id, true)));
} else {
self.push_rpc(EndpointRpcOut::UnsubscribePeerRes(RpcResponse::error(req.req_id)));
}
}
_ => {}
}
}

Expand Down Expand Up @@ -473,6 +496,18 @@ impl MediaEndpointInternal {
/// This should be called when the endpoint is closed
/// - Close all tracks
pub fn before_drop(&mut self, now_ms: u64) {
match self.sub_scope {
EndpointSubscribeScope::RoomAuto => {
self.push_cluster(ClusterEndpointOutgoingEvent::UnsubscribeRoom);
}
EndpointSubscribeScope::RoomManual => {
let peer_subscribe = std::mem::take(&mut self.subscribe_peers);
for peer in peer_subscribe.into_keys() {
self.push_cluster(ClusterEndpointOutgoingEvent::UnsubscribePeer(peer));
}
}
}

let local_tracks = std::mem::take(&mut self.local_tracks);
for (track_id, mut track) in local_tracks {
log::info!("[MediaEndpointInternal {}/{}] close local track {}", self.room_id, self.peer_id, track_id);
Expand Down Expand Up @@ -504,22 +539,24 @@ impl Drop for MediaEndpointInternal {

#[cfg(test)]
mod tests {
use cluster::{ClusterEndpointIncomingEvent, ClusterEndpointOutgoingEvent, ClusterLocalTrackOutgoingEvent, ClusterRemoteTrackOutgoingEvent, ClusterTrackMeta, ClusterTrackUuid};
use cluster::{
ClusterEndpointIncomingEvent, ClusterEndpointOutgoingEvent, ClusterLocalTrackOutgoingEvent, ClusterRemoteTrackOutgoingEvent, ClusterTrackMeta, ClusterTrackUuid, EndpointSubscribeScope,
};
use transport::{
LocalTrackIncomingEvent, LocalTrackOutgoingEvent, MediaPacket, RemoteTrackIncomingEvent, RequestKeyframeKind, TrackMeta, TransportIncomingEvent, TransportOutgoingEvent, TransportStateEvent,
};

use crate::{
endpoint_wrap::internal::{bitrate_limiter::BitrateLimiterType, MediaEndpointInternalEvent, MediaInternalAction, DEFAULT_BITRATE_OUT_BPS},
rpc::{LocalTrackRpcIn, LocalTrackRpcOut, ReceiverSwitch, RemoteStream, TrackInfo},
EndpointRpcOut, RpcRequest, RpcResponse,
rpc::{LocalTrackRpcIn, LocalTrackRpcOut, ReceiverSwitch, RemotePeer, RemoteStream, TrackInfo},
EndpointRpcIn, EndpointRpcOut, RpcRequest, RpcResponse,
};

use super::MediaEndpointInternal;

#[test]
fn should_fire_cluster_when_remote_track_added_then_close() {
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", BitrateLimiterType::DynamicWithConsumers, vec![]);
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", EndpointSubscribeScope::RoomManual, BitrateLimiterType::DynamicWithConsumers, vec![]);

let cluster_track_uuid = ClusterTrackUuid::from_info("room1", "peer1", "audio_main");
endpoint.on_transport(0, TransportIncomingEvent::RemoteTrackAdded("audio_main".to_string(), 100, TrackMeta::new_audio(None)));
Expand Down Expand Up @@ -563,7 +600,7 @@ mod tests {

#[test]
fn should_fire_cluster_when_remote_track_added_then_removed() {
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", BitrateLimiterType::DynamicWithConsumers, vec![]);
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", EndpointSubscribeScope::RoomManual, BitrateLimiterType::DynamicWithConsumers, vec![]);

let cluster_track_uuid = ClusterTrackUuid::from_info("room1", "peer1", "audio_main");
endpoint.on_transport(0, TransportIncomingEvent::RemoteTrackAdded("audio_main".to_string(), 100, TrackMeta::new_audio(None)));
Expand Down Expand Up @@ -606,7 +643,7 @@ mod tests {

#[test]
fn should_fire_rpc_when_cluster_track_added() {
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", BitrateLimiterType::DynamicWithConsumers, vec![]);
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", EndpointSubscribeScope::RoomManual, BitrateLimiterType::DynamicWithConsumers, vec![]);

endpoint.on_cluster(
0,
Expand Down Expand Up @@ -638,7 +675,7 @@ mod tests {

#[test]
fn should_fire_disconnect_when_transport_disconnect() {
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", BitrateLimiterType::DynamicWithConsumers, vec![]);
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", EndpointSubscribeScope::RoomManual, BitrateLimiterType::DynamicWithConsumers, vec![]);

endpoint.on_transport(0, TransportIncomingEvent::State(TransportStateEvent::Disconnected));

Expand All @@ -649,7 +686,7 @@ mod tests {

#[test]
fn should_fire_answer_rpc() {
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", BitrateLimiterType::DynamicWithConsumers, vec![]);
let mut endpoint = MediaEndpointInternal::new("room1", "peer1", EndpointSubscribeScope::RoomManual, BitrateLimiterType::DynamicWithConsumers, vec![]);

endpoint.on_transport(0, TransportIncomingEvent::LocalTrackAdded("video_0".to_string(), 1, TrackMeta::new_video(None)));

Expand Down Expand Up @@ -733,6 +770,97 @@ mod tests {
);
}

#[test]
fn should_fire_room_sub_in_scope_auto() {
    let mut internal = MediaEndpointInternal::new("room1", "peer1", EndpointSubscribeScope::RoomAuto, BitrateLimiterType::DynamicWithConsumers, vec![]);

    // Starting an endpoint in RoomAuto scope must immediately emit a room-wide subscribe.
    internal.on_start(0);

    let first = internal.pop_action();
    assert_eq!(first, Some(MediaInternalAction::Cluster(ClusterEndpointOutgoingEvent::SubscribeRoom)));
    assert_eq!(internal.pop_action(), None);

    // Tearing the endpoint down must undo that room subscription exactly once.
    internal.before_drop(1000);

    let last = internal.pop_action();
    assert_eq!(last, Some(MediaInternalAction::Cluster(ClusterEndpointOutgoingEvent::UnsubscribeRoom)));
    assert_eq!(internal.pop_action(), None);
}

#[test]
fn should_handle_sub_peer_in_scope_manual() {
    let mut internal = MediaEndpointInternal::new("room1", "peer1", EndpointSubscribeScope::RoomManual, BitrateLimiterType::DynamicWithConsumers, vec![]);

    internal.on_start(0);

    // In manual scope, every subscribe rpc must fan out into a cluster SubscribePeer
    // event followed by a success response back to the transport.
    for (req_id, peer) in [(1, "peer2"), (2, "peer3")] {
        let rpc = EndpointRpcIn::SubscribePeer(RpcRequest {
            req_id,
            data: RemotePeer { peer: peer.to_string() },
        });
        internal.on_transport(0, TransportIncomingEvent::Rpc(rpc));

        let cluster = MediaInternalAction::Cluster(ClusterEndpointOutgoingEvent::SubscribePeer(peer.to_string()));
        assert_eq!(internal.pop_action(), Some(cluster));

        let response = MediaInternalAction::Endpoint(TransportOutgoingEvent::Rpc(EndpointRpcOut::SubscribePeerRes(RpcResponse::success(req_id, true))));
        assert_eq!(internal.pop_action(), Some(response));
        assert_eq!(internal.pop_action(), None);
    }

    // Unsubscribing peer3 must emit the matching cluster UnsubscribePeer plus a success rpc.
    let rpc = EndpointRpcIn::UnsubscribePeer(RpcRequest {
        req_id: 3,
        data: RemotePeer { peer: "peer3".to_string() },
    });
    internal.on_transport(0, TransportIncomingEvent::Rpc(rpc));

    let cluster = MediaInternalAction::Cluster(ClusterEndpointOutgoingEvent::UnsubscribePeer("peer3".to_string()));
    assert_eq!(internal.pop_action(), Some(cluster));

    let response = MediaInternalAction::Endpoint(TransportOutgoingEvent::Rpc(EndpointRpcOut::UnsubscribePeerRes(RpcResponse::success(3, true))));
    assert_eq!(internal.pop_action(), Some(response));
    assert_eq!(internal.pop_action(), None);

    // Dropping the endpoint must clean up the one peer that is still subscribed (peer2).
    internal.before_drop(1000);

    let remaining = MediaInternalAction::Cluster(ClusterEndpointOutgoingEvent::UnsubscribePeer("peer2".to_string()));
    assert_eq!(internal.pop_action(), Some(remaining));
    assert_eq!(internal.pop_action(), None);
}

#[test]
fn should_forward_remote_track_stats() {
//TODO
Expand Down
2 changes: 2 additions & 0 deletions packages/endpoint/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ pub enum EndpointRpcOut {
TrackAdded(TrackInfo),
TrackUpdated(TrackInfo),
TrackRemoved(TrackInfo),
SubscribePeerRes(RpcResponse<bool>),
UnsubscribePeerRes(RpcResponse<bool>),
}

#[derive(Debug, PartialEq, Eq)]
Expand Down
Loading

0 comments on commit 68a6f4f

Please sign in to comment.