From b3b5344a45248292684b9983b7d62e2afb29d20a Mon Sep 17 00:00:00 2001 From: hongcha Date: Mon, 18 Dec 2023 23:50:32 +0800 Subject: [PATCH] feat: layer extensions --- src/dto/mod.rs | 1 + src/dto/req.rs | 7 ++++ src/forward/forward_internal.rs | 61 ++++++++++++++++++++++++++++++- src/{layer.rs => forward/info.rs} | 1 - src/forward/mod.rs | 11 +++++- src/main.rs | 54 +++++++++++++++++++++++++-- src/path/manager.rs | 18 ++++++++- 7 files changed, 145 insertions(+), 8 deletions(-) create mode 100644 src/dto/mod.rs create mode 100644 src/dto/req.rs rename src/{layer.rs => forward/info.rs} (86%) diff --git a/src/dto/mod.rs b/src/dto/mod.rs new file mode 100644 index 00000000..1b7e4723 --- /dev/null +++ b/src/dto/mod.rs @@ -0,0 +1 @@ +pub mod req; diff --git a/src/dto/req.rs b/src/dto/req.rs new file mode 100644 index 00000000..6b107f53 --- /dev/null +++ b/src/dto/req.rs @@ -0,0 +1,7 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct SelectLayer { + #[serde(rename = "encodingId")] + pub encoding_id: Option, +} diff --git a/src/forward/forward_internal.rs b/src/forward/forward_internal.rs index 0312de39..665dc959 100644 --- a/src/forward/forward_internal.rs +++ b/src/forward/forward_internal.rs @@ -30,6 +30,7 @@ use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; use webrtc::track::track_remote::TrackRemote; +use crate::forward::info::Layer; use crate::media; use crate::AppError; @@ -219,7 +220,65 @@ impl PeerForwardInternal { } } } - return Err(anyhow::anyhow!("anchor is none")); + Err(anyhow::anyhow!("anchor svc rids error")) + } + + pub async fn select_layer(&self, key: String, layer: Option) -> Result<()> { + let rid = if let Some(layer) = layer { + layer.encoding_id + } else { + self.publish_svc_rids().await?[0].clone() + }; + let peer: Option = self + .subscribe_group + .read() + .await + .iter() + .filter(|p| p.get_key() == key) + .map(|p| p.clone()) + .next(); + if let Some(peer) = peer { + let anchor_track_forward_map = self.anchor_track_forward_map.write().await; + for (track_remote, track_forward) in anchor_track_forward_map.iter() { + if track_remote.0.rid() == rid && track_remote.0.kind() == RTPCodecType::Video { + for (track_remote_original, track_forward_original) in + anchor_track_forward_map.iter() + { + if track_remote_original.0.kind() != RTPCodecType::Video { + continue; + } + let mut subscription_group = + track_forward_original.subscription_group.write().await; + if subscription_group.contains_key(&peer) { + if track_remote_original.0.rid() == rid { + return Ok(()); + } + let sender = subscription_group.remove(&peer).unwrap(); + drop(subscription_group); + track_forward + .subscription_group + .write() + .await + .insert(peer.clone(), sender); + let _ = track_forward + .rtcp_send + .try_send(RtcpMessage::PictureLossIndication); + info!( + "[{}] [subscribe] [{}] select layer {} to {} ", + self.id, + peer.get_key(), + track_remote_original.0.rid(), + rid + ); + return Ok(()); + } + } + } + } + Err(anyhow::anyhow!("not found layer")) + } else { + Err(anyhow::anyhow!("not found key")) + } } pub async fn remove_subscribe(&self, peer: Arc) -> Result<()> { diff --git a/src/layer.rs b/src/forward/info.rs similarity index 86% rename from src/layer.rs rename to src/forward/info.rs index a6099d8d..a8cd5796 100644 --- a/src/layer.rs +++ b/src/forward/info.rs @@ -4,5 +4,4 @@ use serde::{Deserialize, Serialize}; pub struct Layer { #[serde(rename = "encodingId")] pub encoding_id: String, - // TODO Other fields } diff --git a/src/forward/mod.rs b/src/forward/mod.rs index 8ae86cd2..093620b4 100644 --- a/src/forward/mod.rs +++ b/src/forward/mod.rs @@ -13,11 +13,12 @@ use webrtc::rtp_transceiver::rtp_codec::RTPCodecType; use webrtc::sdp::{MediaDescription, SessionDescription}; use crate::forward::forward_internal::{get_peer_key, PeerForwardInternal}; -use crate::layer::Layer; +use crate::forward::info::Layer; use crate::AppError; use crate::{media, metrics}; mod forward_internal; +pub mod info; mod rtcp; mod track_match; @@ -170,7 +171,15 @@ impl PeerForward { Err(anyhow::anyhow!("not layers")) } } + + pub async fn select_layer(&self, key: String, layer: Option) -> Result<()> { + if !self.internal.publish_is_svc().await { + return Err(anyhow::anyhow!("anchor svc is not enabled")); + } + self.internal.select_layer(key, layer).await + } } + async fn peer_complete( offer: RTCSessionDescription, peer: Arc, diff --git a/src/main.rs b/src/main.rs index 8d8aff26..34805fad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use axum::http::{HeaderMap, Uri}; use axum::routing::get; +use axum::Json; use axum::{ extract::{Path, State}, http::StatusCode, @@ -12,6 +13,7 @@ use axum::{ routing::post, Router, }; +use forward::info::Layer; use http::header::ToStrError; use log::info; use thiserror::Error; @@ -27,11 +29,12 @@ use {http::header, rust_embed::RustEmbed}; use crate::auth::ManyValidate; use crate::config::Config; +use crate::dto::req::SelectLayer; mod auth; mod config; +mod dto; mod forward; -mod layer; mod media; mod metrics; mod path; @@ -92,9 +95,13 @@ async fn main() { post(whep) .patch(add_ice_candidate) .delete(remove_path_key) - .layer(auth_layer) + .layer(auth_layer.clone()) .options(ice_server_config), ) + .route( + "/whep/:id/layer", + get(get_layer).post(select_layer).layer(auth_layer), + ) .route("/metrics", get(metrics)) .with_state(app_state); app = static_server(app); @@ -195,8 +202,13 @@ async fn whep( .header("E-Tag", key) .header("Location", uri.to_string()); if state.paths.layers(id).await.is_ok() { - builder = builder.header("Link", format!("<{}/layer>; rel=\"urn:ietf:params:whep:ext:core:layer\"", uri.to_string())) - .header("Link", format!("<{}/sse_info>; rel=\"urn:ietf:params:whep:ext:core:server-sent-events\"; events=\"layers\"", uri.to_string())) + builder = builder.header( + "Link", + format!( + "<{}/layer>; rel=\"urn:ietf:params:whep:ext:core:layer\"", + uri.to_string() + ), + ) } Ok(builder.body(answer.sdp)?) } @@ -254,6 +266,40 @@ async fn ice_server_config(State(state): State) -> AppResult, + Path(id): Path, +) -> AppResult>> { + let layers = state.paths.layers(id).await?; + Ok(Json(layers)) +} + +async fn select_layer( + State(state): State, + Path(id): Path, + header: HeaderMap, + Json(layer): Json, +) -> AppResult { + let key = header + .get("If-Match") + .ok_or(AppError::from(anyhow::anyhow!("If-Match is required")))? + .to_str()? + .to_string(); + state + .paths + .select_layer( + id, + key, + if let Some(encoding_id) = layer.encoding_id { + Some(Layer { encoding_id }) + } else { + None + }, + ) + .await?; + Ok("".to_string()) +} + fn link_header(ice_servers: Vec) -> Vec { ice_servers .into_iter() diff --git a/src/path/manager.rs b/src/path/manager.rs index 59a95399..5617ef00 100644 --- a/src/path/manager.rs +++ b/src/path/manager.rs @@ -8,8 +8,8 @@ use webrtc::{ peer_connection::sdp::session_description::RTCSessionDescription, }; +use crate::forward::info::Layer; use crate::forward::PeerForward; -use crate::layer::Layer; use crate::AppError; #[derive(Clone)] @@ -103,4 +103,20 @@ impl Manager { Err(anyhow::anyhow!("resource not exists")) } } + + pub async fn select_layer( + &self, + path: String, + key: String, + layer: Option, + ) -> Result<()> { + let paths = self.paths.read().await; + let forward = paths.get(&path).cloned(); + drop(paths); + if let Some(forward) = forward { + forward.select_layer(key, layer).await + } else { + Err(anyhow::anyhow!("resource not exists")) + } + } }