From be6191bad5a019286feb9eff3b2506ac00b3e3d7 Mon Sep 17 00:00:00 2001 From: giangndm <45644921+giangndm@users.noreply.github.com> Date: Mon, 15 Apr 2024 10:11:32 +0700 Subject: [PATCH] feat: ext worker communication (#160) This PR implement ext worker communication for better performance. This PR also includes a whip-whep example for showing how it can be used. --- Cargo.toml | 2 +- bin/src/main.rs | 18 +- examples/Cargo.toml | 4 +- examples/whip-whep/Cargo.toml | 20 + examples/whip-whep/public/index.html | 10 + examples/whip-whep/public/whep/index.html | 16 + examples/whip-whep/public/whep/whep.demo.js | 36 + examples/whip-whep/public/whep/whep.js | 494 +++++++++++ examples/whip-whep/public/whip/index.html | 16 + examples/whip-whep/public/whip/whip.demo.js | 58 ++ examples/whip-whep/public/whip/whip.js | 395 +++++++++ examples/whip-whep/src/http.rs | 113 +++ examples/whip-whep/src/main.rs | 150 ++++ examples/whip-whep/src/sfu/cluster.rs | 138 +++ examples/whip-whep/src/sfu/media.rs | 31 + examples/whip-whep/src/sfu/mod.rs | 388 +++++++++ examples/whip-whep/src/sfu/shared_port.rs | 72 ++ examples/whip-whep/src/sfu/whep.rs | 197 +++++ examples/whip-whep/src/sfu/whip.rs | 173 ++++ examples/whip-whep/src/worker.rs | 365 ++++++++ packages/network/Cargo.toml | 2 +- packages/network/src/base/buf.rs | 409 --------- packages/network/src/base/feature.rs | 37 +- packages/network/src/base/mod.rs | 3 +- packages/network/src/base/secure.rs | 6 +- packages/network/src/base/service.rs | 18 +- packages/network/src/controller_plane.rs | 93 +- .../src/controller_plane/neighbours.rs | 2 +- packages/network/src/data_plane.rs | 155 +++- packages/network/src/data_plane/connection.rs | 6 +- packages/network/src/data_plane/features.rs | 23 +- packages/network/src/data_plane/services.rs | 8 +- packages/network/src/features/alias.rs | 2 +- packages/network/src/features/data.rs | 1 + .../network/src/features/pubsub/controller.rs | 160 +++- .../features/pubsub/controller/feedbacks.rs | 211 +++++ .../features/pubsub/controller/local_relay.rs | 46 +- .../pubsub/controller/remote_relay.rs | 99 ++- .../features/pubsub/controller/source_hint.rs | 808 ++++++++++++++++++ packages/network/src/features/pubsub/mod.rs | 17 +- packages/network/src/features/pubsub/msg.rs | 76 ++ .../network/src/features/pubsub/worker.rs | 56 +- packages/network/src/features/router_sync.rs | 1 + packages/network/src/features/socket.rs | 4 +- packages/network/src/features/vpn.rs | 19 +- packages/network/src/lib.rs | 41 +- .../src/secure/encryption/x25519_dalek_aes.rs | 46 +- .../network/src/services/manual_discovery.rs | 4 +- .../network/src/services/visualization.rs | 4 +- packages/network/src/worker.rs | 217 +++++ packages/network/tests/feature_alias.rs | 4 +- packages/network/tests/feature_pubsub.rs | 207 ++++- packages/network/tests/feature_router_sync.rs | 4 +- .../network/tests/service_visualization.rs | 41 +- packages/network/tests/simulator.rs | 236 ++--- packages/runner/Cargo.toml | 3 +- packages/runner/examples/simple_kv.rs | 20 +- packages/runner/examples/simple_node.rs | 10 +- packages/runner/run-example-debug.sh | 3 + packages/runner/run-example-release.sh | 3 + packages/runner/src/builder.rs | 19 +- .../src/{tasks/data_plane => }/history.rs | 2 +- packages/runner/src/lib.rs | 50 +- packages/runner/src/tasks/controller_plane.rs | 106 --- packages/runner/src/tasks/data_plane.rs | 214 ----- .../tasks/event_convert/controller_plane.rs | 58 -- .../src/tasks/event_convert/data_plane.rs | 48 -- .../runner/src/tasks/event_convert/mod.rs | 2 - packages/runner/src/tasks/mod.rs | 247 ------ packages/runner/src/worker_inner.rs | 278 ++++++ packages/runner/tests/feature_dht_kv.rs | 112 +++ 71 files changed, 5445 insertions(+), 1492 deletions(-) create mode 100644 examples/whip-whep/Cargo.toml create mode 100644 examples/whip-whep/public/index.html create mode 100644 examples/whip-whep/public/whep/index.html create mode 100644 examples/whip-whep/public/whep/whep.demo.js create mode 100644 examples/whip-whep/public/whep/whep.js create mode 100644 examples/whip-whep/public/whip/index.html create mode 100644 examples/whip-whep/public/whip/whip.demo.js create mode 100644 examples/whip-whep/public/whip/whip.js create mode 100644 examples/whip-whep/src/http.rs create mode 100644 examples/whip-whep/src/main.rs create mode 100644 examples/whip-whep/src/sfu/cluster.rs create mode 100644 examples/whip-whep/src/sfu/media.rs create mode 100644 examples/whip-whep/src/sfu/mod.rs create mode 100644 examples/whip-whep/src/sfu/shared_port.rs create mode 100644 examples/whip-whep/src/sfu/whep.rs create mode 100644 examples/whip-whep/src/sfu/whip.rs create mode 100644 examples/whip-whep/src/worker.rs delete mode 100644 packages/network/src/base/buf.rs create mode 100644 packages/network/src/features/pubsub/controller/feedbacks.rs create mode 100644 packages/network/src/features/pubsub/controller/source_hint.rs create mode 100644 packages/network/src/worker.rs create mode 100644 packages/runner/run-example-debug.sh create mode 100644 packages/runner/run-example-release.sh rename packages/runner/src/{tasks/data_plane => }/history.rs (97%) delete mode 100644 packages/runner/src/tasks/controller_plane.rs delete mode 100644 packages/runner/src/tasks/data_plane.rs delete mode 100644 packages/runner/src/tasks/event_convert/controller_plane.rs delete mode 100644 packages/runner/src/tasks/event_convert/data_plane.rs delete mode 100644 packages/runner/src/tasks/event_convert/mod.rs delete mode 100644 packages/runner/src/tasks/mod.rs create mode 100644 packages/runner/src/worker_inner.rs create mode 100644 packages/runner/tests/feature_dht_kv.rs diff --git a/Cargo.toml b/Cargo.toml index a8599146..af669722 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,4 +23,4 @@ parking_lot = "0.12" env_logger = "0.11" clap = { version = "4.4", features = ["derive", "env"] } mockall = "0.12.1" -sans-io-runtime = { git = "https://github.com/8xff/sans-io-runtime.git", rev = "c2d0c78ae5bfa7ab9c7942a7333a21a439fc5edb"} \ No newline at end of file +sans-io-runtime = { git = "https://github.com/8xFF/sans-io-runtime.git", rev = "1be9705b4fe9852b7c1ac66dc610fedf94a83971" } \ No newline at end of file diff --git a/bin/src/main.rs b/bin/src/main.rs index dcbd6922..5acaa1c4 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -1,12 +1,11 @@ -use atm0s_sdn::sans_io_runtime::Owner; use atm0s_sdn::secure::StaticKeyAuthorization; use atm0s_sdn::services::visualization; -use atm0s_sdn::tasks::{SdnExtIn, SdnExtOut}; use atm0s_sdn::{ - sans_io_runtime::backend::{MioBackend, PollBackend, PollingBackend}, + sans_io_runtime::backend::{PollBackend, PollingBackend}, services::visualization::ConnectionInfo, }; use atm0s_sdn::{NodeAddr, NodeId}; +use atm0s_sdn::{SdnBuilder, SdnExtIn, SdnExtOut, SdnOwner}; use clap::{Parser, ValueEnum}; use futures_util::{SinkExt, StreamExt}; #[cfg(not(feature = "embed"))] @@ -14,7 +13,6 @@ use poem::endpoint::StaticFilesEndpoint; #[cfg(feature = "embed")] use poem::endpoint::{EmbeddedFileEndpoint, EmbeddedFilesEndpoint}; use poem::{ - endpoint::StaticFilesEndpoint, get, handler, listener::TcpListener, web::{ @@ -38,8 +36,6 @@ use std::{ }; use tokio::sync::Mutex; -use atm0s_sdn::builder::SdnBuilder; - #[cfg(feature = "embed")] #[derive(RustEmbed)] #[folder = "public"] @@ -49,7 +45,6 @@ pub struct Files; enum BackendType { Poll, Polling, - Mio, } /// Simple program to running a node @@ -256,15 +251,14 @@ async fn main() { } let mut controller = match args.backend { - BackendType::Mio => builder.build::>(args.workers), - BackendType::Poll => builder.build::>(args.workers), - BackendType::Polling => builder.build::>(args.workers), + BackendType::Poll => builder.build::>(args.workers), + BackendType::Polling => builder.build::>(args.workers), }; let ctx = Arc::new(Mutex::new(WebsocketCtx::new())); if args.collector { - controller.send_to(Owner::worker(0), SdnExtIn::ServicesControl(visualization::SERVICE_ID.into(), visualization::Control::Subscribe)); + controller.send_to(0, SdnExtIn::ServicesControl(visualization::SERVICE_ID.into(), visualization::Control::Subscribe)); let ctx_c = ctx.clone(); tokio::spawn(async move { let route = Route::new().at("/ws", get(ws.data(ctx_c))); @@ -292,7 +286,7 @@ async fn main() { } while let Some(event) = controller.pop_event() { match event { - SdnExtOut::ServicesEvent(event) => match event { + SdnExtOut::ServicesEvent(_service, event) => match event { visualization::Event::GotAll(all) => { log::info!("Got all: {:?}", all); ctx.lock().await.set_snapshot(all); diff --git a/examples/Cargo.toml b/examples/Cargo.toml index fe27596b..dd09a749 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -2,7 +2,7 @@ resolver = "2" members = [ "quic-tunnel" -] +, "whip-whep"] [workspace.package] version = "0.1.0" @@ -15,4 +15,4 @@ tracing-subscriber = "0.3" signal-hook = "0.3" clap = { version = "4.4", features = ["derive", "env"] } tokio = { version = "1", features = ["full"] } -log = "0.4" \ No newline at end of file +log = "0.4" diff --git a/examples/whip-whep/Cargo.toml b/examples/whip-whep/Cargo.toml new file mode 100644 index 00000000..a56d1d13 --- /dev/null +++ b/examples/whip-whep/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "whip-whep" +version.workspace = true +edition.workspace = true +publish.workspace = true + +[dependencies] +sans-io-runtime = { git = "https://github.com/8xFF/sans-io-runtime.git", rev = "1be9705b4fe9852b7c1ac66dc610fedf94a83971" } +atm0s-sdn = { path = "../../packages/runner" } +derive_more = "0.99.17" +str0m = "0.5.0" +tiny_http = "0.12.0" +signal-hook = "0.3.17" +env_logger = "0.11.3" +log.workspace = true +faster-stun = "1.0.2" +clap.workspace = true +serde = "1.0.197" +bincode = "1.3.3" +rand = "0.8.5" diff --git a/examples/whip-whep/public/index.html b/examples/whip-whep/public/index.html new file mode 100644 index 00000000..4018470d --- /dev/null +++ b/examples/whip-whep/public/index.html @@ -0,0 +1,10 @@ + + + + + + \ No newline at end of file diff --git a/examples/whip-whep/public/whep/index.html b/examples/whip-whep/public/whep/index.html new file mode 100644 index 00000000..c1dfe22f --- /dev/null +++ b/examples/whip-whep/public/whep/index.html @@ -0,0 +1,16 @@ + + + Whep + + + +
+ + + +
+
+ +
+ + \ No newline at end of file diff --git a/examples/whip-whep/public/whep/whep.demo.js b/examples/whip-whep/public/whep/whep.demo.js new file mode 100644 index 00000000..4009f28c --- /dev/null +++ b/examples/whip-whep/public/whep/whep.demo.js @@ -0,0 +1,36 @@ +import { WHEPClient } from "./whep.js" + +window.start = async () => { + console.log("Will start"); + //Create peerconnection + const pc = window.pc = new RTCPeerConnection(); + + //Add recv only transceivers + pc.addTransceiver("audio", { direction: 'recvonly' }); + pc.addTransceiver("video", { direction: 'recvonly' }); + + let stream = new MediaStream(); + document.querySelector("video").srcObject = stream; + pc.ontrack = (event) => { + stream.addTrack(event.track); + } + + //Create whep client + const whep = new WHEPClient(); + + const url = "/whep/endpoint"; + const token = document.getElementById("room-id").value; + + //Start viewing + whep.view(pc, url, token); + + window.whep_instance = whep; +} + +window.stop = async () => { + if (window.whep_instance) { + window.whep_instance.stop(); + } + + document.getElementById("video").srcObject = null; +} \ No newline at end of file diff --git a/examples/whip-whep/public/whep/whep.js b/examples/whip-whep/public/whep/whep.js new file mode 100644 index 00000000..f047b22e --- /dev/null +++ b/examples/whip-whep/public/whep/whep.js @@ -0,0 +1,494 @@ +const Extensions = { + Core: { + ServerSentEvents: "urn:ietf:params:whep:ext:core:server-sent-events", + Layer : "urn:ietf:params:whep:ext:core:layer", + } +} + + +export class WHEPClient extends EventTarget +{ + constructor() + { + super(); + //Ice properties + this.iceUsername = null; + this.icePassword = null; + //Pending candidadtes + this.candidates = []; + this.endOfcandidates = false; + } + + async view(pc, url, token) + { + //If already publishing + if (this.pc) + throw new Error("Already viewing") + + //Store pc object and token + this.token = token; + this.pc = pc; + + //Listen for state change events + pc.onconnectionstatechange = (event) => + { + switch (pc.connectionState) + { + case "connected": + // The connection has become fully connected + break; + case "disconnected": + case "failed": + // One or more transports has terminated unexpectedly or in an error + break; + case "closed": + // The connection has been closed + break; + } + } + + //Listen for candidates + pc.onicecandidate = (event) => + { + + if (event.candidate) + { + //Ignore candidates not from the first m line + if (event.candidate.sdpMLineIndex > 0) + //Skip + return; + //Store candidate + this.candidates.push(event.candidate); + } else + { + //No more candidates + this.endOfcandidates = true; + } + //Schedule trickle on next tick + if (!this.iceTrickeTimeout) + this.iceTrickeTimeout = setTimeout(() => this.trickle(), 0); + } + //Create SDP offer + const offer = await pc.createOffer(); + + //Request headers + const headers = { + "Content-Type": "application/sdp" + }; + + //If token is set + if (token) + headers["Authorization"] = "Bearer " + token; + + //Do the post request to the WHEP endpoint with the SDP offer + const fetched = await fetch(url, { + method: "POST", + body: offer.sdp, + headers + }); + + if (!fetched.ok) + throw new Error("Request rejected with status " + fetched.status) + if (!fetched.headers.get("location")) + throw new Error("Response missing location header") + + //Get the resource url + this.resourceURL = fetched.headers.get("location"); + + //Get the links + const links = {}; + + //If the response contained any + if (fetched.headers.has("link")) + { + //Get all links headers + const linkHeaders = fetched.headers.get("link").split(/,\s+(?=<)/) + + //For each one + for (const header of linkHeaders) + { + try + { + let rel, params = {}; + //Split in parts + const items = header.split(";"); + //Create url server + const url = items[0].trim().replace(/<(.*)>/, "$1").trim(); + //For each other item + for (let i = 1; i < items.length; ++i) + { + //Split into key/val + const subitems = items[i].split(/=(.*)/); + //Get key + const key = subitems[0].trim(); + //Unquote value + const value = subitems[1] + ? subitems[1] + .trim() + .replaceAll('"', '') + .replaceAll("'", "") + : subitems[1]; + //Check if it is the rel attribute + if (key == "rel") + //Get rel value + rel = value; + else + //Unquote value and set them + params[key] = value + } + //Ensure it is an ice server + if (!rel) + continue; + if (!links[rel]) + links[rel] = []; + //Add to config + links[rel].push({url, params}); + } catch (e){ + console.error(e) + } + } + } + + //Get extensions url + if (links.hasOwnProperty(Extensions.Core.ServerSentEvents)) + //Get url + this.eventsUrl = new URL(links[Extensions.Core.ServerSentEvents][0].url, url); + if (links.hasOwnProperty(Extensions.Core.Layer)) + this.layerUrl = new URL(links[Extensions.Core.Layer][0].url, url); + + //If we have an event url + if (this.eventsUrl) + { + //Get supported events + const events = links[Extensions.Core.ServerSentEvents]["events"] + ? links[Extensions.Core.ServerSentEvents]["events"].split(" ") + : ["active","inactive","layers","viewercount"]; + //Request headers + const headers = { + "Content-Type": "application/json" + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Do the post request to the whep resource + fetch(this.eventsUrl, { + method: "POST", + body: JSON.stringify(events), + headers + }).then((fetched)=>{ + //If the event channel could be created + if (!fetched.ok) + return; + //Get the resource url + const sseUrl = new URL(fetched.headers.get("location"), this.eventsUrl); + //Open it + this.eventSource = new EventSource(sseUrl); + this.eventSource.onopen = (event) => console.log(event); + this.eventSource.onerror = (event) => console.log(event); + //Listen for events + this.eventSource.onmessage = (event) => { + console.dir(event); + this.dispatchEvent(event); + }; + }); + } + + //Get current config + const config = pc.getConfiguration(); + + //If it has ice server info and it is not overridden by the client + if ((!config.iceServer || !config.iceServer.length) && links.hasOwnProperty("ice-server")) + { + //ICe server config + config.iceServers = []; + + //For each one + for (const server of links["ice-server"]) + { + try + { + //Create ice server + const iceServer = { + urls : server.url + } + //For each other param + for (const [key,value] of Object.entries(server.params)) + { + //Get key in cammel case + const cammelCase = key.replace(/([-_][a-z])/ig, $1 => $1.toUpperCase().replace('-', '').replace('_', '')) + //Unquote value and set them + iceServer[cammelCase] = value; + } + //Add to config + //config.iceServers.push(iceServer); + } catch (e){ + } + } + + //If any configured + if (config.iceServers.length) + //Set it + pc.setConfiguration(config); + } + + //Get the SDP answer + const answer = await fetched.text(); + console.log(answer); + + //Schedule trickle on next tick + if (!this.iceTrickeTimeout) + this.iceTrickeTimeout = setTimeout(() => this.trickle(), 0); + + //Set local description + await pc.setLocalDescription(offer); + + // TODO: chrome is returning a wrong value, so don't use it for now + //try { + // //Get local ice properties + // const local = this.pc.getTransceivers()[0].sender.transport.iceTransport.getLocalParameters(); + // //Get them for transport + // this.iceUsername = local.usernameFragment; + // this.icePassword = local.password; + //} catch (e) { + //Fallback for browsers not supporting ice transport + this.iceUsername = offer.sdp.match(/a=ice-ufrag:(.*)\r\n/)[1]; + this.icePassword = offer.sdp.match(/a=ice-pwd:(.*)\r\n/)[1]; + //} + + //And set remote description + await pc.setRemoteDescription({ type: "answer", sdp: answer }); + } + + restart() + { + //Set restart flag + this.restartIce = true; + + //Schedule trickle on next tick + if (!this.iceTrickeTimeout) + this.iceTrickeTimeout = setTimeout(() => this.trickle(), 0); + } + + async trickle() + { + //Clear timeout + this.iceTrickeTimeout = null; + + //Check if there is any pending data + if (!(this.candidates.length || this.endOfcandidates || this.restartIce) || !this.resourceURL) + //Do nothing + return; + + //Get data + const candidates = this.candidates; + let endOfcandidates = this.endOfcandidates; + const restartIce = this.restartIce; + + //Clean pending data before async operation + this.candidates = []; + this.endOfcandidates = false; + this.restartIce = false; + + //If we need to restart + if (restartIce) + { + //Restart ice + this.pc.restartIce(); + //Create a new offer + const offer = await this.pc.createOffer({ iceRestart: true }); + //Update ice + this.iceUsername = offer.sdp.match(/a=ice-ufrag:(.*)\r\n/)[1]; + this.icePassword = offer.sdp.match(/a=ice-pwd:(.*)\r\n/)[1]; + //Set it + await this.pc.setLocalDescription(offer); + //Clean end of candidates flag as new ones will be retrieved + endOfcandidates = false; + } + //Prepare fragment + let fragment = + "a=ice-ufrag:" + this.iceUsername + "\r\n" + + "a=ice-pwd:" + this.icePassword + "\r\n"; + //Get peerconnection transceivers + const transceivers = this.pc.getTransceivers(); + //Get medias + const medias = {}; + //If doing something else than a restart + if (candidates.length || endOfcandidates) + //Create media object for first media always + medias[transceivers[0].mid] = { + mid: transceivers[0].mid, + kind: transceivers[0].receiver.track.kind, + candidates: [], + }; + //For each candidate + for (const candidate of candidates) + { + //Get mid for candidate + const mid = candidate.sdpMid + //Get associated transceiver + const transceiver = transceivers.find(t => t.mid == mid); + //Get media + let media = medias[mid]; + //If not found yet + if (!media) + //Create media object + media = medias[mid] = { + mid, + kind: transceiver.receiver.track.kind, + candidates: [], + }; + //Add candidate + media.candidates.push(candidate); + } + //For each media + for (const media of Object.values(medias)) + { + //Add media to fragment + fragment += + "m=" + media.kind + " 9 RTP/AVP 0\r\n" + + "a=mid:" + media.mid + "\r\n"; + //Add candidate + for (const candidate of media.candidates) + fragment += "a=" + candidate.candidate + "\r\n"; + if (endOfcandidates) + fragment += "a=end-of-candidates\r\n"; + } + + //Request headers + const headers = { + "Content-Type": "application/trickle-ice-sdpfrag" + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Do the post request to the WHEP resource + const fetched = await fetch(this.resourceURL, { + method: "PATCH", + body: fragment, + headers + }); + if (!fetched.ok) + throw new Error("Request rejected with status " + fetched.status) + + //If we have got an answer + if (fetched.status == 200) + { + //Get the SDP answer + const answer = await fetched.text(); + //Get remote icename and password + const iceUsername = answer.match(/a=ice-ufrag:(.*)\r\n/)[1]; + const icePassword = answer.match(/a=ice-pwd:(.*)\r\n/)[1]; + + //Get current remote rescription + const remoteDescription = this.pc.remoteDescription; + + //Patch + remoteDescription.sdp = remoteDescription.sdp.replaceAll(/(a=ice-ufrag:)(.*)\r\n/gm, "$1" + iceUsername + "\r\n"); + remoteDescription.sdp = remoteDescription.sdp.replaceAll(/(a=ice-pwd:)(.*)\r\n/gm, "$1" + icePassword + "\r\n"); + + //Set it + await this.pc.setRemoteDescription(remoteDescription); + } + } + + async mute(muted) + { + //Request headers + const headers = { + "Content-Type": "application/json" + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Do the post request to the whep resource + const fetched = await fetch(this.resourceURL, { + method: "POST", + body: JSON.stringify(muted), + headers + }); + } + + async selectLayer(layer) + { + if (!this.layerUrl) + throw new Error("whep resource does not support layer selection"); + + //Request headers + const headers = { + "Content-Type": "application/json" + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Do the post request to the whep resource + const fetched = await fetch(this.layerUrl, { + method: "POST", + body: JSON.stringify(layer), + headers + }); + } + + async unselectLayer() + { + if (!this.layerUrl) + throw new Error("whep resource does not support layer selection"); + + + //Request headers + const headers = {}; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Do the post request to the whep resource + const fetched = await fetch(this.layerUrl, { + method: "DELETE", + headers + }); + } + + async stop() + { + if (!this.pc) + { + // Already stopped + return + } + + //Cancel any pending timeout + this.iceTrickeTimeout = clearTimeout(this.iceTrickeTimeout); + + //Close peerconnection + this.pc.close(); + + //Null + this.pc = null; + + //If we don't have the resource url + if (!this.resourceURL) + throw new Error("WHEP resource url not available yet"); + + //Request headers + const headers = { + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Send a delete + await fetch(this.resourceURL, { + method: "DELETE", + headers + }); + } +}; \ No newline at end of file diff --git a/examples/whip-whep/public/whip/index.html b/examples/whip-whep/public/whip/index.html new file mode 100644 index 00000000..547e06a9 --- /dev/null +++ b/examples/whip-whep/public/whip/index.html @@ -0,0 +1,16 @@ + + + Whip + + + +
+ + + +
+
+ +
+ + \ No newline at end of file diff --git a/examples/whip-whep/public/whip/whip.demo.js b/examples/whip-whep/public/whip/whip.demo.js new file mode 100644 index 00000000..866ca047 --- /dev/null +++ b/examples/whip-whep/public/whip/whip.demo.js @@ -0,0 +1,58 @@ +import { WHIPClient } from "./whip.js" + +window.start = async () => { + console.log("Will start"); + if (window.whip_instance) { + window.whip_instance.stop(); + } + + if (window.stream_instance) { + window.stream_instance.getTracks().forEach(track => track.stop()); + } + + //Get mic+cam + const stream = await navigator.mediaDevices.getUserMedia({audio:true, video:true}); + + document.getElementById("video").srcObject = stream; + + //Create peerconnection + const pc = new RTCPeerConnection(); + + //Send all tracks + for (const track of stream.getTracks()) { + //You could add simulcast too here + pc.addTransceiver(track, { + direction: "sendonly", + streams: [stream], + // sendEncodings: [ + // { rid: "0", active: true, scaleResolutionDownBy: 2}, + // { rid: "1", active: true, scaleResolutionDownBy: 2}, + // { rid: "2", active: true }, + // ], + }); + } + + //Create whip client + const whip = new WHIPClient(); + + const url = "/whip/endpoint"; + const token = document.getElementById("room-id").value; + + //Start publishing + whip.publish(pc, url, token); + + window.whip_instance = whip; + window.stream_instance = stream; +} + +window.stop = async () => { + if (window.whip_instance) { + window.whip_instance.stop(); + } + + if (window.stream_instance) { + window.stream_instance.getTracks().forEach(track => track.stop()); + } + + document.getElementById("video").srcObject = null; +} \ No newline at end of file diff --git a/examples/whip-whep/public/whip/whip.js b/examples/whip-whep/public/whip/whip.js new file mode 100644 index 00000000..a6438022 --- /dev/null +++ b/examples/whip-whep/public/whip/whip.js @@ -0,0 +1,395 @@ + +//import { EventEmitter } from "events"; + +export class WHIPClient +{ + constructor() + { + //Ice properties + this.iceUsername = null; + this.icePassword = null; + //Pending candidadtes + this.candidates = []; + this.endOfcandidates = false; + } + + async publish(pc, url, token) + { + //If already publishing + if (this.pc) + throw new Error("Already publishing") + + //Store pc object and token + this.token = token; + this.pc = pc; + + //Listen for state change events + pc.onconnectionstatechange = (event) =>{ + switch(pc.connectionState) { + case "connected": + // The connection has become fully connected + break; + case "disconnected": + case "failed": + // One or more transports has terminated unexpectedly or in an error + break; + case "closed": + // The connection has been closed + break; + } + } + + //Listen for candidates + pc.onicecandidate = (event)=>{ + + if (event.candidate) + { + //Ignore candidates not from the first m line + if (event.candidate.sdpMLineIndex>0) + //Skip + return; + //Store candidate + this.candidates.push(event.candidate); + } else { + //No more candidates + this.endOfcandidates = true; + } + //Schedule trickle on next tick + if (!this.iceTrickeTimeout) + this.iceTrickeTimeout = setTimeout(()=>this.trickle(),0); + } + //Create SDP offer + const offer = await pc.createOffer(); + + //Request headers + const headers = { + "Content-Type": "application/sdp" + }; + + //If token is set + if (token) + headers["Authorization"] = "Bearer " + token; + + //Do the post request to the WHIP endpoint with the SDP offer + const fetched = await fetch(url, { + method: "POST", + body: offer.sdp, + headers + }); + + if (!fetched.ok) + throw new Error("Request rejected with status " + fetched.status) + if (!fetched.headers.get("location")) { + throw new Error("Response missing location header") + } + + //Get the resource url + this.resourceURL = fetched.headers.get("location"); + + //Get the links + const links = {}; + + //If the response contained any + if (fetched.headers.has("link")) + { + //Get all links headers + const linkHeaders = fetched.headers.get("link").split(/,\s+(?=<)/) + + //For each one + for (const header of linkHeaders) + { + try + { + let rel, params = {}; + //Split in parts + const items = header.split(";"); + //Create url server + const url = items[0].trim().replace(/<(.*)>/, "$1").trim(); + //For each other item + for (let i = 1; i < items.length; ++i) + { + //Split into key/val + const subitems = items[i].split(/=(.*)/); + //Get key + const key = subitems[0].trim(); + //Unquote value + const value = subitems[1] + ? subitems[1] + .trim() + .replaceAll('"', '') + .replaceAll("'", "") + : subitems[1]; + //Check if it is the rel attribute + if (key == "rel") + //Get rel value + rel = value; + else + //Unquote value and set them + params[key] = value + } + //Ensure it is an ice server + if (!rel) + continue; + if (!links[rel]) + links[rel] = []; + //Add to config + links[rel].push({url, params}); + } catch (e){ + console.error(e) + } + } + } + + //Get current config + const config = pc.getConfiguration(); + + //If it has ice server info and it is not overridden by the client + if ((!config.iceServer || !config.iceServer.length) && links.hasOwnProperty("ice-server")) + { + //ICe server config + config.iceServers = []; + + //For each one + for (const server of links["ice-server"]) + { + try + { + //Create ice server + const iceServer = { + urls : server.url + } + //For each other param + for (const [key,value] of Object.entries(server.params)) + { + //Get key in cammel case + const cammelCase = key.replace(/([-_][a-z])/ig, $1 => $1.toUpperCase().replace('-', '').replace('_', '')) + //Unquote value and set them + iceServer[cammelCase] = value; + } + //Add to config + config.iceServers.push(iceServer); + } catch (e){ + } + } + + //If any configured + if (config.iceServers.length) + //Set it + pc.setConfiguration(config); + } + + //Get the SDP answer + const answer = await fetched.text(); + + //Schedule trickle on next tick + if (!this.iceTrickeTimeout) + this.iceTrickeTimeout = setTimeout(() => this.trickle(), 0); + + //Set local description + await pc.setLocalDescription(offer); + + // TODO: chrome is returning a wrong value, so don't use it for now + //try { + // //Get local ice properties + // const local = this.pc.getTransceivers()[0].sender.transport.iceTransport.getLocalParameters(); + // //Get them for transport + // this.iceUsername = local.usernameFragment; + // this.icePassword = local.password; + //} catch (e) { + //Fallback for browsers not supporting ice transport + this.iceUsername = offer.sdp.match(/a=ice-ufrag:(.*)\r\n/)[1]; + this.icePassword = offer.sdp.match(/a=ice-pwd:(.*)\r\n/)[1]; + //} + + //And set remote description + await pc.setRemoteDescription({type:"answer",sdp: answer}); + } + + restart() + { + //Set restart flag + this.restartIce = true; + + //Schedule trickle on next tick + if (!this.iceTrickeTimeout) + this.iceTrickeTimeout = setTimeout(()=>this.trickle(),0); + } + + async trickle() + { + //Clear timeout + this.iceTrickeTimeout = null; + + //Check if there is any pending data + if (!(this.candidates.length || this.endOfcandidates || this.restartIce) || !this.resourceURL ) + //Do nothing + return; + + //Get data + const candidates = this.candidates; + let endOfcandidates = this.endOfcandidates; + const restartIce = this.restartIce; + + //Clean pending data before async operation + this.candidates = []; + this.endOfcandidates = false; + this.restartIce = false; + + //If we need to restart + if (restartIce) + { + //Restart ice + this.pc.restartIce(); + //Create a new offer + const offer = await this.pc.createOffer({iceRestart: true}); + //Update ice + this.iceUsername = offer.sdp.match(/a=ice-ufrag:(.*)\r\n/)[1]; + this.icePassword = offer.sdp.match(/a=ice-pwd:(.*)\r\n/)[1]; + //Set it + await this.pc.setLocalDescription(offer); + //Clean end of candidates flag as new ones will be retrieved + endOfcandidates = false; + } + //Prepare fragment + let fragment = + "a=ice-ufrag:" + this.iceUsername + "\r\n" + + "a=ice-pwd:" + this.icePassword + "\r\n"; + //Get peerconnection transceivers + const transceivers = this.pc.getTransceivers(); + //Get medias + const medias = {}; + //If doing something else than a restart + if (candidates.length || endOfcandidates) + //Create media object for first media always + medias[transceivers[0].mid] = { + mid: transceivers[0].mid, + kind: transceivers[0].receiver.track.kind, + candidates: [], + }; + //For each candidate + for (const candidate of candidates) + { + //Get mid for candidate + const mid = candidate.sdpMid + //Get associated transceiver + const transceiver = transceivers.find(t=>t.mid==mid); + //Get media + let media = medias[mid]; + //If not found yet + if (!media) + //Create media object + media = medias[mid] = { + mid, + kind : transceiver.receiver.track.kind, + candidates: [], + }; + //Add candidate + media.candidates.push(candidate); + } + //For each media + for (const media of Object.values(medias)) + { + //Add media to fragment + fragment += + "m="+ media.kind + " 9 RTP/AVP 0\r\n" + + "a=mid:"+ media.mid + "\r\n"; + //Add candidate + for (const candidate of media.candidates) + fragment += "a=" + candidate.candidate + "\r\n"; + if (endOfcandidates) + fragment += "a=end-of-candidates\r\n"; + } + + //Request headers + const headers = { + "Content-Type": "application/trickle-ice-sdpfrag" + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Do the post request to the WHIP resource + const fetched = await fetch(this.resourceURL, { + method: "PATCH", + body: fragment, + headers + }); + if (!fetched.ok) + throw new Error("Request rejected with status " + fetched.status) + + //If we have got an answer + if (fetched.status==200) + { + //Get the SDP answer + const answer = await fetched.text(); + //Get remote icename and password + const iceUsername = answer.match(/a=ice-ufrag:(.*)\r\n/)[1]; + const icePassword = answer.match(/a=ice-pwd:(.*)\r\n/)[1]; + + //Get current remote rescription + const remoteDescription = this.pc.remoteDescription; + + //Patch + remoteDescription.sdp = remoteDescription.sdp.replaceAll(/(a=ice-ufrag:)(.*)\r\n/gm , "$1" + iceUsername + "\r\n"); + remoteDescription.sdp = remoteDescription.sdp.replaceAll(/(a=ice-pwd:)(.*)\r\n/gm , "$1" + icePassword + "\r\n"); + + //Set it + await this.pc.setRemoteDescription(remoteDescription); + } + } + + async mute(muted) + { + //Request headers + const headers = { + "Content-Type": "application/json" + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Do the post request to the WHIP resource + const fetched = await fetch(this.resourceURL, { + method: "POST", + body: JSON.stringify(muted), + headers + }); + } + + async stop() + { + if (!this.pc) { + // Already stopped + return + } + + //Cancel any pending timeout + this.iceTrickeTimeout = clearTimeout(this.iceTrickeTimeout); + + //Close peerconnection + this.pc.close(); + + //Null + this.pc = null; + + //If we don't have the resource url + if (!this.resourceURL) + throw new Error("WHIP resource url not available yet"); + + //Request headers + const headers = { + }; + + //If token is set + if (this.token) + headers["Authorization"] = "Bearer " + this.token; + + //Send a delete + await fetch(this.resourceURL, { + method: "DELETE", + headers + }); + } +}; \ No newline at end of file diff --git a/examples/whip-whep/src/http.rs b/examples/whip-whep/src/http.rs new file mode 100644 index 00000000..05d36337 --- /dev/null +++ b/examples/whip-whep/src/http.rs @@ -0,0 +1,113 @@ +use std::io::Read; +use std::{collections::HashMap, fs::File, net::SocketAddr, path::Path, time::Duration}; +use tiny_http::{Header, Method, Request, Response, Server}; + +#[derive(Debug, Clone)] +pub struct HttpRequest { + pub req_id: u64, + pub method: String, + pub path: String, + pub headers: HashMap, + pub body: Vec, +} + +impl HttpRequest { + pub fn http_auth(&self) -> String { + if let Some(auth) = self.headers.get("Authorization") { + auth.clone() + } else if let Some(auth) = self.headers.get("authorization") { + auth.clone() + } else { + "demo".to_string() + } + } +} + +#[derive(Debug, Clone)] +pub struct HttpResponse { + pub req_id: u64, + pub status: u16, + pub headers: HashMap, + pub body: Vec, +} + +pub struct SimpleHttpServer { + req_id_seed: u64, + server: Server, + reqs: HashMap, +} + +impl SimpleHttpServer { + pub fn new(port: u16) -> Self { + Self { + req_id_seed: 0, + server: Server::http(SocketAddr::from(([0, 0, 0, 0], port))).expect("Should open http port"), + reqs: HashMap::new(), + } + } + + pub fn send_response(&mut self, res: HttpResponse) { + log::info!("sending response for request_id {}, status {}", res.req_id, res.status); + let req = self.reqs.remove(&res.req_id).expect("Should have a request."); + let mut response = Response::from_data(res.body).with_status_code(res.status); + for (k, v) in res.headers { + response.add_header(Header::from_bytes(k.as_bytes(), v.as_bytes()).unwrap()); + } + response.add_header(Header::from_bytes("Access-Control-Allow-Origin", "*").unwrap()); + response.add_header(Header::from_bytes("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS").unwrap()); + response.add_header(Header::from_bytes("Access-Control-Allow-Headers", "*").unwrap()); + response.add_header(Header::from_bytes("Access-Control-Allow-Credentials", "true").unwrap()); + req.respond(response).unwrap(); + } + + pub fn recv(&mut self, timeout: Duration) -> Result, std::io::Error> { + let mut request = if let Some(req) = self.server.recv_timeout(timeout)? { + req + } else { + return Ok(None); + }; + if request.url().starts_with("/public") { + if let Ok(file) = File::open(&Path::new(&format!(".{}", request.url()))) { + let mut response = tiny_http::Response::from_file(file); + if request.url().ends_with(".js") { + response.add_header(Header::from_bytes("Content-Type", "application/javascript").unwrap()); + } else if request.url().ends_with(".css") { + response.add_header(Header::from_bytes("Content-Type", "text/css").unwrap()); + } + request.respond(response).expect("Should respond file."); + return Ok(None); + } else { + let response = Response::from_string("Not Found"); + request.respond(response.with_status_code(404)).expect("Should respond 404."); + return Ok(None); + } + } + + if request.method().eq(&Method::Options) { + let mut response = Response::from_string("OK"); + //setting CORS + response.add_header(Header::from_bytes("Access-Control-Allow-Origin", "*").unwrap()); + response.add_header(Header::from_bytes("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS").unwrap()); + response.add_header(Header::from_bytes("Access-Control-Allow-Headers", "*").unwrap()); + response.add_header(Header::from_bytes("Access-Control-Allow-Credentials", "true").unwrap()); + + request.respond(response).expect("Should respond options."); + return Ok(None); + } + + log::info!("received request_id {} method: {}, url: {}", self.req_id_seed, request.method(), request.url(),); + + let req_id = self.req_id_seed; + self.req_id_seed += 1; + + let res = Ok(Some(HttpRequest { + req_id, + method: request.method().to_string(), + path: request.url().to_string(), + headers: request.headers().iter().map(|h| (h.field.to_string(), h.value.to_string())).collect(), + body: request.as_reader().bytes().map(|b| b.unwrap()).collect(), + })); + self.reqs.insert(req_id, request); + res + } +} diff --git a/examples/whip-whep/src/main.rs b/examples/whip-whep/src/main.rs new file mode 100644 index 00000000..9c4ca97a --- /dev/null +++ b/examples/whip-whep/src/main.rs @@ -0,0 +1,150 @@ +use std::{ + net::SocketAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, + vec, +}; + +use atm0s_sdn::{ + base::ServiceBuilder, + features::{FeaturesControl, FeaturesEvent}, + secure::{HandshakeBuilderXDA, StaticKeyAuthorization}, + services::visualization, + ControllerPlaneCfg, DataPlaneCfg, DataWorkerHistory, NodeAddr, NodeAddrBuilder, NodeId, Protocol, SdnExtIn, SdnWorkerCfg, +}; +use clap::Parser; +use sans_io_runtime::{backend::PollingBackend, Controller}; + +use worker::{ChannelId, Event, ExtIn, ExtOut, ICfg, SCfg, SC, SE, TC, TW}; + +use crate::worker::{ControllerCfg, RunnerOwner, RunnerWorker, SdnInnerCfg}; + +mod http; +mod sfu; +mod worker; + +/// Quic-tunnel demo application +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + /// Node Id + #[arg(env, short, long, default_value_t = 1)] + node_id: NodeId, + + /// Listen address + #[arg(env, short, long, default_value_t = 10000)] + udp_port: u16, + + /// Address of node we should connect to + #[arg(env, short, long)] + seeds: Vec, + + /// Password for the network + #[arg(env, short, long, default_value = "password")] + password: String, + + /// Workers + #[arg(env, long, default_value_t = 1)] + workers: usize, + + /// Http listen port + #[arg(env, long, default_value_t = 8080)] + http_port: u16, +} + +fn main() { + if std::env::var_os("RUST_LOG").is_none() { + std::env::set_var("RUST_LOG", "info"); + } + let args = Args::parse(); + env_logger::builder().format_timestamp_millis().init(); + + let auth = Arc::new(StaticKeyAuthorization::new(&args.password)); + let history = Arc::new(DataWorkerHistory::default()); + + let mut server = http::SimpleHttpServer::new(args.http_port); + let mut controller = Controller::::default(); + let services: Vec>> = vec![Arc::new(visualization::VisualizationServiceBuilder::::new(false))]; + + let mut addr_builder = NodeAddrBuilder::new(args.node_id); + addr_builder.add_protocol(Protocol::Ip4("192.168.1.27".parse().unwrap())); + addr_builder.add_protocol(Protocol::Udp(args.udp_port)); + let addr = addr_builder.addr(); + log::info!("Node address: {}", addr); + + controller.add_worker::>( + Duration::from_millis(10), + ICfg { + sfu: "192.168.1.27:0".parse().unwrap(), + sdn: SdnInnerCfg { + node_id: args.node_id, + tick_ms: 1000, + udp_port: args.udp_port, + controller: Some(ControllerCfg { + session: 0, + auth, + handshake: Arc::new(HandshakeBuilderXDA), + }), + services: services.clone(), + history: history.clone(), + #[cfg(feature = "vpn")] + vpn_tun_fd: None, + }, + sdn_listen: SocketAddr::from(([0, 0, 0, 0], args.udp_port)), + }, + None, + ); + + for _ in 1..args.workers { + controller.add_worker::>( + Duration::from_millis(10), + ICfg { + sfu: "192.168.1.27:0".parse().unwrap(), + sdn: SdnInnerCfg { + node_id: args.node_id, + tick_ms: 1000, + udp_port: args.udp_port, + controller: None, + services: services.clone(), + history: history.clone(), + #[cfg(feature = "vpn")] + vpn_tun_fd: None, + }, + sdn_listen: SocketAddr::from(([0, 0, 0, 0], args.udp_port)), + }, + None, + ); + } + + for seed in args.seeds { + controller.send_to(0, ExtIn::Sdn(SdnExtIn::ConnectTo(seed))); + } + + let term = Arc::new(AtomicBool::new(false)); + signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(&term)).expect("Should register hook"); + + while let Ok(req) = server.recv(Duration::from_millis(100)) { + if controller.process().is_none() { + break; + } + if term.load(Ordering::Relaxed) { + controller.shutdown(); + } + while let Some(ext) = controller.pop_event() { + match ext { + ExtOut::HttpResponse(resp) => { + server.send_response(resp); + } + ExtOut::Sdn(event) => {} + } + } + if let Some(req) = req { + controller.send_to_best(ExtIn::HttpRequest(req)); + } + } + + log::info!("Server shutdown"); +} diff --git a/examples/whip-whep/src/sfu/cluster.rs b/examples/whip-whep/src/sfu/cluster.rs new file mode 100644 index 00000000..ed252c25 --- /dev/null +++ b/examples/whip-whep/src/sfu/cluster.rs @@ -0,0 +1,138 @@ +use std::{ + collections::HashMap, + hash::{DefaultHasher, Hash, Hasher}, + time::Instant, +}; + +use atm0s_sdn::features::pubsub::{self, ChannelControl, Feedback}; +use str0m::media::KeyframeRequestKind; + +use super::{TrackMedia, WhepOwner, WhipOwner}; + +pub fn room_channel(room: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + room.hash(&mut hasher); + hasher.finish() +} + +pub enum Input { + Pubsub(pubsub::Event), + WhipStart(WhipOwner, String), + WhipStop(WhipOwner), + WhipMedia(WhipOwner, TrackMedia), + WhepStart(WhepOwner, String), + WhepStop(WhepOwner), + WhepRequest(WhepOwner, KeyframeRequestKind), +} + +pub enum Output { + Pubsub(pubsub::Control), + WhepMedia(Vec, TrackMedia), + WhipControl(Vec, KeyframeRequestKind), +} + +pub struct Channel { + whips: Vec, + wheps: Vec, +} + +#[derive(Default)] +pub struct ClusterLogic { + channels: HashMap, + whips: HashMap, + wheps: HashMap, +} + +impl ClusterLogic { + pub fn on_tick(&mut self, now: Instant) -> Option { + None + } + + pub fn on_input(&mut self, now: Instant, input: Input) -> Option { + match input { + Input::Pubsub(pubsub::Event(channel, event)) => match event { + pubsub::ChannelEvent::RouteChanged(_) => None, + pubsub::ChannelEvent::SourceData(_, data) => { + let pkt = TrackMedia::from_buffer(&data); + let channel = self.channels.get(&channel)?; + Some(Output::WhepMedia(channel.wheps.clone(), pkt)) + } + pubsub::ChannelEvent::FeedbackData(fb) => { + let channel = self.channels.get(&channel)?; + let kind = match fb.kind { + 0 => KeyframeRequestKind::Pli, + _ => KeyframeRequestKind::Fir, + }; + Some(Output::WhipControl(channel.whips.clone(), kind)) + } + }, + Input::WhipStart(owner, room) => { + log::info!("WhipStart: {:?}, {:?}", owner, room); + let channel_id = room_channel(&room); + self.whips.insert(owner, channel_id); + let channel = self.channels.entry(channel_id).or_insert(Channel { whips: Vec::new(), wheps: Vec::new() }); + channel.whips.push(owner); + if channel.whips.len() == 1 { + Some(Output::Pubsub(pubsub::Control(channel_id.into(), pubsub::ChannelControl::PubStart))) + } else { + None + } + } + Input::WhipStop(owner) => { + log::info!("WhipStop: {:?}", owner); + let channel_id = self.whips.remove(&owner)?; + let channel = self.channels.get_mut(&channel_id)?; + channel.whips.retain(|&o| o != owner); + if channel.whips.is_empty() { + Some(Output::Pubsub(pubsub::Control(channel_id.into(), pubsub::ChannelControl::PubStop))) + } else { + None + } + } + Input::WhipMedia(owner, media) => { + log::trace!("WhipMedia: {:?}, {}", owner, media.seq_no); + let channel_id = self.whips.get(&owner)?; + let buf = media.to_buffer(); + Some(Output::Pubsub(pubsub::Control((*channel_id).into(), pubsub::ChannelControl::PubData(buf)))) + } + Input::WhepStart(owner, room) => { + log::info!("WhepStart: {:?}, {:?}", owner, room); + let channel_id = room_channel(&room); + self.wheps.insert(owner, channel_id); + let channel = self.channels.entry(channel_id).or_insert(Channel { whips: Vec::new(), wheps: Vec::new() }); + channel.wheps.push(owner); + if channel.wheps.len() == 1 { + Some(Output::Pubsub(pubsub::Control(channel_id.into(), pubsub::ChannelControl::SubAuto))) + } else { + None + } + } + Input::WhepStop(owner) => { + log::info!("WhepStop: {:?}", owner); + let channel_id = self.wheps.remove(&owner)?; + let channel = self.channels.get_mut(&channel_id)?; + channel.wheps.retain(|&o| o != owner); + if channel.wheps.is_empty() { + Some(Output::Pubsub(pubsub::Control(channel_id.into(), pubsub::ChannelControl::UnsubAuto))) + } else { + None + } + } + Input::WhepRequest(owner, kind) => { + let kind = match kind { + KeyframeRequestKind::Pli => 0, + KeyframeRequestKind::Fir => 1, + }; + let channel_id = self.wheps.get(&owner)?; + Some(Output::Pubsub(pubsub::Control( + (*channel_id).into(), + ChannelControl::FeedbackAuto(Feedback::simple(kind, 1, 1000, 2000)), + ))) + } + } + } + + pub fn pop_output(&mut self, now: Instant) -> Option { + None + } +} diff --git a/examples/whip-whep/src/sfu/media.rs b/examples/whip-whep/src/sfu/media.rs new file mode 100644 index 00000000..006e5d66 --- /dev/null +++ b/examples/whip-whep/src/sfu/media.rs @@ -0,0 +1,31 @@ +use serde::{Deserialize, Serialize}; +use str0m::rtp::RtpPacket; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrackMedia { + pub seq_no: u64, + pub pt: u8, + pub time: u32, + pub marker: bool, + pub payload: Vec, +} + +impl TrackMedia { + pub fn from_raw(rtp: RtpPacket) -> Self { + Self { + seq_no: *rtp.seq_no, + pt: *rtp.header.payload_type, + time: rtp.header.timestamp, + marker: rtp.header.marker, + payload: rtp.payload, + } + } + + pub fn to_buffer(&self) -> Vec { + bincode::serialize(self).expect("") + } + + pub fn from_buffer(data: &[u8]) -> Self { + bincode::deserialize(data).expect("") + } +} diff --git a/examples/whip-whep/src/sfu/mod.rs b/examples/whip-whep/src/sfu/mod.rs new file mode 100644 index 00000000..0d6be5e1 --- /dev/null +++ b/examples/whip-whep/src/sfu/mod.rs @@ -0,0 +1,388 @@ +use std::{ + collections::{HashMap, VecDeque}, + net::SocketAddr, + time::Instant, +}; + +use sans_io_runtime::{group_owner_type, group_task, Buffer, TaskSwitcher}; + +use atm0s_sdn::features::pubsub; +use str0m::change::DtlsCert; + +use crate::http::{HttpRequest, HttpResponse}; + +use self::{ + cluster::ClusterLogic, + media::TrackMedia, + shared_port::SharedUdpPort, + whep::{WhepInput, WhepOutput, WhepTask}, + whip::{WhipInput, WhipOutput, WhipTask}, +}; + +mod cluster; +mod media; +mod shared_port; +mod whep; +mod whip; + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct SfuChannel; + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +enum TaskId { + Whip(usize), + Whep(usize), +} + +#[derive(Debug, Clone)] +pub enum Input<'a> { + HttpRequest(HttpRequest), + PubsubEvent(pubsub::Event), + UdpBind { addr: SocketAddr }, + UdpPacket { from: std::net::SocketAddr, data: Buffer<'a> }, +} + +#[derive(Debug, Clone)] +pub enum Output { + HttpResponse(HttpResponse), + PubsubControl(pubsub::Control), + UdpPacket { to: std::net::SocketAddr, data: Vec }, + Continue, +} + +group_owner_type!(WhipOwner); +group_task!(WhipTaskGroup, WhipTask, WhipInput<'a>, WhipOutput); + +group_owner_type!(WhepOwner); +group_task!(WhepTaskGroup, WhepTask, WhepInput<'a>, WhepOutput); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SfuOwner; + +pub struct SfuWorker { + worker: u16, + dtls_cert: DtlsCert, + cluster: ClusterLogic, + whip_group: WhipTaskGroup, + whep_group: WhepTaskGroup, + output: VecDeque, + shared_udp: SharedUdpPort, + switcher: TaskSwitcher, +} + +impl SfuWorker { + fn process_req(&mut self, req: HttpRequest) { + match req.path.as_str() { + "/whip/endpoint" => self.connect_whip(req), + "/whep/endpoint" => self.connect_whep(req), + _ => { + self.output.push_back(Output::HttpResponse(HttpResponse { + req_id: req.req_id, + status: 404, + headers: HashMap::new(), + body: b"Task Not Found".to_vec(), + })); + } + } + } + + fn connect_whip(&mut self, req: HttpRequest) { + let http_auth = req.http_auth(); + log::info!("Whip endpoint connect request: {}", http_auth); + let room = http_auth; + let task = WhipTask::build(self.shared_udp.get_backend_addr().expect(""), self.dtls_cert.clone(), room, &String::from_utf8_lossy(&req.body)); + match task { + Ok(task) => { + log::info!("Whip endpoint created {}", task.ice_ufrag); + let index = self.whip_group.add_task(task.task); + self.shared_udp.add_ufrag(task.ice_ufrag, TaskId::Whip(index)); + self.output.push_back(Output::HttpResponse(HttpResponse { + req_id: req.req_id, + status: 200, + headers: HashMap::from([ + ("Content-Type".to_string(), "application/sdp".to_string()), + ("Location".to_string(), format!("/whip/endpoint/{}/{index}", self.worker)), + ]), + body: task.sdp.into_bytes(), + })); + } + Err(err) => { + log::error!("Error creating whip endpoint: {}", err); + self.output.push_back(Output::HttpResponse(HttpResponse { + req_id: req.req_id, + status: 500, + headers: HashMap::new(), + body: err.into_bytes(), + })); + } + } + } + + fn connect_whep(&mut self, req: HttpRequest) { + let http_auth = req.http_auth(); + log::info!("Whep endpoint connect request: {}", http_auth); + let room = http_auth; + let task = WhepTask::build(self.shared_udp.get_backend_addr().expect(""), self.dtls_cert.clone(), room, &String::from_utf8_lossy(&req.body)); + match task { + Ok(task) => { + log::info!("Whep endpoint created {}", task.ice_ufrag); + let index = self.whep_group.add_task(task.task); + self.shared_udp.add_ufrag(task.ice_ufrag, TaskId::Whep(index)); + self.output.push_back(Output::HttpResponse(HttpResponse { + req_id: req.req_id, + status: 200, + headers: HashMap::from([ + ("Content-Type".to_string(), "application/sdp".to_string()), + ("Location".to_string(), format!("/whep/endpoint/{}/{index}", self.worker)), + ]), + body: task.sdp.into_bytes(), + })); + } + Err(err) => { + log::error!("Error creating whep endpoint: {}", err); + self.output.push_back(Output::HttpResponse(HttpResponse { + req_id: req.req_id, + status: 500, + headers: HashMap::new(), + body: err.into_bytes(), + })); + } + } + } +} + +#[repr(u8)] +enum TaskType { + Cluster = 0, + Whip = 1, + Whep = 2, +} + +impl From for TaskType { + fn from(value: usize) -> Self { + match value { + 0 => Self::Cluster, + 1 => Self::Whip, + 2 => Self::Whep, + _ => panic!("Should not happen"), + } + } +} + +impl SfuWorker { + fn process_cluster_output<'a>(&mut self, now: Instant, out: cluster::Output) -> Output { + self.switcher.queue_flag_task(TaskType::Cluster as usize); + match out { + cluster::Output::Pubsub(control) => Output::PubsubControl(control), + cluster::Output::WhepMedia(owners, media) => { + for owner in owners { + if let Some(out) = self.whep_group.on_event(now, owner.index(), WhepInput::Media(&media)) { + let out = self.process_whep_out(now, owner.index(), out); + self.output.push_back(out); + } + } + Output::Continue + } + cluster::Output::WhipControl(owners, kind) => { + for owner in owners { + if let Some(out) = self.whip_group.on_event(now, owner.index(), WhipInput::KeyFrame(kind)) { + let out = self.process_whip_out(now, owner.index(), out); + self.output.push_back(out); + } + } + Output::Continue + } + } + } + + fn process_whip_out<'a>(&mut self, now: Instant, index: usize, out: WhipOutput) -> Output { + self.switcher.queue_flag_task(TaskType::Whip as usize); + match out { + WhipOutput::Started(room) => { + if let Some(out) = self.cluster.on_input(now, cluster::Input::WhipStart(WhipOwner(index), room)) { + self.process_cluster_output(now, out) + } else { + Output::Continue + } + } + WhipOutput::Media(media) => { + if let Some(out) = self.cluster.on_input(now, cluster::Input::WhipMedia(WhipOwner(index), media)) { + self.process_cluster_output(now, out) + } else { + Output::Continue + } + } + WhipOutput::UdpPacket { to, data } => Output::UdpPacket { to, data }, + WhipOutput::Destroy => { + self.shared_udp.remove_task(TaskId::Whip(index)); + self.whip_group.remove_task(index); + log::info!("destroy whip({index}) => remain {}", self.whip_group.tasks()); + if let Some(out) = self.cluster.on_input(now, cluster::Input::WhipStop(WhipOwner(index))) { + self.process_cluster_output(now, out) + } else { + Output::Continue + } + } + } + } + + fn process_whep_out<'a>(&mut self, now: Instant, index: usize, out: WhepOutput) -> Output { + self.switcher.queue_flag_task(TaskType::Whep as usize); + match out { + WhepOutput::Started(room) => { + if let Some(out) = self.cluster.on_input(now, cluster::Input::WhepStart(WhepOwner(index), room)) { + self.process_cluster_output(now, out) + } else { + Output::Continue + } + } + WhepOutput::RequestKey(kind) => { + if let Some(out) = self.cluster.on_input(now, cluster::Input::WhepRequest(WhepOwner(index), kind)) { + self.process_cluster_output(now, out) + } else { + Output::Continue + } + } + WhepOutput::UdpPacket { to, data } => Output::UdpPacket { to, data }, + WhepOutput::Destroy => { + self.shared_udp.remove_task(TaskId::Whip(index)); + self.whep_group.remove_task(index); + log::info!("destroy whep({index}) => remain {}", self.whep_group.tasks()); + if let Some(out) = self.cluster.on_input(now, cluster::Input::WhepStop(WhepOwner(index))) { + self.process_cluster_output(now, out) + } else { + Output::Continue + } + } + } + } +} + +impl SfuWorker { + pub fn build(worker: u16) -> Self { + Self { + worker, + dtls_cert: DtlsCert::new_openssl(), + cluster: ClusterLogic::default(), + whip_group: WhipTaskGroup::default(), + whep_group: WhepTaskGroup::default(), + shared_udp: SharedUdpPort::default(), + switcher: TaskSwitcher::new(3), + output: VecDeque::new(), + } + } + pub fn worker_index(&self) -> u16 { + self.worker + } + pub fn tasks(&self) -> usize { + self.whip_group.tasks() + self.whep_group.tasks() + } + pub fn on_tick<'a>(&mut self, now: Instant) -> Option { + if let Some(e) = self.output.pop_front() { + return Some(e.into()); + } + + let switcher = &mut self.switcher; + loop { + match switcher.looper_current(now)?.into() { + TaskType::Cluster => { + if let Some(out) = switcher.looper_process(self.cluster.on_tick(now)) { + return Some(self.process_cluster_output(now, out)); + } + } + TaskType::Whip => { + if let Some((index, out)) = switcher.looper_process(self.whip_group.on_tick(now)) { + return Some(self.process_whip_out(now, index, out)); + } + } + TaskType::Whep => { + if let Some((index, out)) = switcher.looper_process(self.whep_group.on_tick(now)) { + return Some(self.process_whep_out(now, index, out)); + } + } + } + } + } + + pub fn on_event<'a>(&mut self, now: Instant, input: Input) -> Option { + match input { + Input::UdpBind { addr } => { + log::info!("UdpBind: {}", addr); + self.shared_udp.set_backend_info(addr); + None + } + Input::UdpPacket { from, data } => match self.shared_udp.map_remote(from, &data) { + Some(TaskId::Whip(index)) => { + let out = self.whip_group.on_event(now, index, WhipInput::UdpPacket { from, data })?; + Some(self.process_whip_out(now, index, out)) + } + Some(TaskId::Whep(index)) => { + let out = self.whep_group.on_event(now, index, WhepInput::UdpPacket { from, data })?; + Some(self.process_whep_out(now, index, out)) + } + None => { + log::debug!("Unknown remote address: {}", from); + None + } + }, + Input::HttpRequest(req) => { + self.process_req(req); + self.output.pop_front() + } + Input::PubsubEvent(event) => { + if let Some(out) = self.cluster.on_input(now, cluster::Input::Pubsub(event)) { + Some(self.process_cluster_output(now, out)) + } else { + None + } + } + } + } + + pub fn pop_output<'a>(&mut self, now: Instant) -> Option { + let switcher = &mut self.switcher; + while let Some(current) = switcher.queue_current() { + match current.into() { + TaskType::Cluster => { + if let Some(out) = switcher.queue_process(self.cluster.pop_output(now)) { + return Some(self.process_cluster_output(now, out)); + } + } + TaskType::Whip => { + if let Some((index, out)) = switcher.queue_process(self.whip_group.pop_output(now)) { + return Some(self.process_whip_out(now, index, out)); + } + } + TaskType::Whep => { + if let Some((index, out)) = switcher.queue_process(self.whep_group.pop_output(now)) { + return Some(self.process_whep_out(now, index, out)); + } + } + } + } + None + } + + pub fn shutdown<'a>(&mut self, now: Instant) -> Option { + let switcher = &mut self.switcher; + loop { + match switcher.looper_current(now)?.into() { + TaskType::Cluster => { + if let Some(out) = switcher.looper_process(None) { + return Some(self.process_cluster_output(now, out)); + } + } + TaskType::Whip => { + if let Some((index, out)) = switcher.looper_process(self.whip_group.shutdown(now)) { + return Some(self.process_whip_out(now, index, out)); + } + } + TaskType::Whep => { + if let Some((index, out)) = switcher.looper_process(self.whep_group.shutdown(now)) { + return Some(self.process_whep_out(now, index, out)); + } + } + } + } + } +} diff --git a/examples/whip-whep/src/sfu/shared_port.rs b/examples/whip-whep/src/sfu/shared_port.rs new file mode 100644 index 00000000..990ef282 --- /dev/null +++ b/examples/whip-whep/src/sfu/shared_port.rs @@ -0,0 +1,72 @@ +use faster_stun::attribute::*; +use faster_stun::*; +use std::{collections::HashMap, fmt::Debug, hash::Hash, net::SocketAddr}; + +#[derive(Debug)] +pub struct SharedUdpPort { + backend_addr: Option, + task_remotes: HashMap, + task_remotes_map: HashMap>, + task_ufrags: HashMap, + task_ufrags_reverse: HashMap, +} + +impl Default for SharedUdpPort { + fn default() -> Self { + Self { + backend_addr: None, + task_remotes: HashMap::new(), + task_remotes_map: HashMap::new(), + task_ufrags: HashMap::new(), + task_ufrags_reverse: HashMap::new(), + } + } +} + +impl SharedUdpPort { + pub fn set_backend_info(&mut self, addr: SocketAddr) { + self.backend_addr = Some(addr); + } + + pub fn get_backend_addr(&self) -> Option { + self.backend_addr + } + + pub fn add_ufrag(&mut self, ufrag: String, task: Task) { + log::info!("Add ufrag {} to task {:?}", ufrag, task); + self.task_ufrags.insert(ufrag.clone(), task); + self.task_ufrags_reverse.insert(task, ufrag); + } + + pub fn remove_task(&mut self, task: Task) -> Option<()> { + let ufrag = self.task_ufrags_reverse.remove(&task)?; + log::info!("Remove task {:?} => ufrag {}", task, ufrag); + self.task_ufrags.remove(&ufrag)?; + let remotes = self.task_remotes_map.remove(&task)?; + for remote in remotes { + log::info!(" Remove remote {:?} => task {:?}", remote, task); + self.task_remotes.remove(&remote); + } + Some(()) + } + + pub fn map_remote(&mut self, remote: SocketAddr, buf: &[u8]) -> Option { + if let Some(task) = self.task_remotes.get(&remote) { + return Some(*task); + } + + let stun_username = Self::get_stun_username(buf)?; + log::warn!("Received a stun packet from an unknown remote: {:?}, username {}", remote, stun_username); + let task = self.task_ufrags.get(stun_username)?; + log::info!("Mapping remote {:?} to task {:?}", remote, task); + self.task_remotes.insert(remote, *task); + self.task_remotes_map.entry(*task).or_default().push(remote); + Some(*task) + } + + fn get_stun_username(buf: &[u8]) -> Option<&str> { + let mut attributes = Vec::new(); + let message = MessageReader::decode(buf, &mut attributes).ok()?; + message.get::().map(|u| u.split(':').next())? + } +} diff --git a/examples/whip-whep/src/sfu/whep.rs b/examples/whip-whep/src/sfu/whep.rs new file mode 100644 index 00000000..bb4e68ca --- /dev/null +++ b/examples/whip-whep/src/sfu/whep.rs @@ -0,0 +1,197 @@ +use std::{net::SocketAddr, time::Instant}; + +use sans_io_runtime::Buffer; +use str0m::{ + change::{DtlsCert, SdpOffer}, + ice::IceCreds, + media::{KeyframeRequestKind, MediaKind, Mid}, + net::{Protocol, Receive}, + Candidate, Event as Str0mEvent, IceConnectionState, Input, Output, Rtc, +}; + +use super::TrackMedia; + +pub struct WhepTaskBuildResult { + pub task: WhepTask, + pub ice_ufrag: String, + pub sdp: String, +} + +pub enum WhepInput<'a> { + UdpPacket { from: SocketAddr, data: Buffer<'a> }, + Media(&'a TrackMedia), +} + +pub enum WhepOutput { + UdpPacket { to: SocketAddr, data: Vec }, + Started(String), + RequestKey(KeyframeRequestKind), + Destroy, +} + +pub struct WhepTask { + backend_addr: SocketAddr, + timeout: Option, + rtc: Rtc, + audio_mid: Option, + video_mid: Option, + room: String, +} + +impl WhepTask { + pub fn build(backend_addr: SocketAddr, dtls_cert: DtlsCert, room: String, sdp: &str) -> Result { + let rtc_config = Rtc::builder().set_rtp_mode(true).set_ice_lite(true).set_dtls_cert(dtls_cert).set_local_ice_credentials(IceCreds::new()); + + let ice_ufrag = rtc_config.local_ice_credentials().as_ref().expect("should have ice credentials").ufrag.clone(); + + let mut rtc = rtc_config.build(); + rtc.direct_api().enable_twcc_feedback(); + + rtc.add_local_candidate(Candidate::host(backend_addr, Protocol::Udp).expect("Should create candidate")); + + let offer = SdpOffer::from_sdp_string(&sdp).expect("Should parse offer"); + let answer = rtc.sdp_api().accept_offer(offer).expect("Should accept offer"); + let instance = Self { + backend_addr, + timeout: None, + rtc, + audio_mid: None, + video_mid: None, + room, + }; + + Ok(WhepTaskBuildResult { + task: instance, + ice_ufrag, + sdp: answer.to_sdp_string(), + }) + } + + fn pop_event_inner(&mut self, now: Instant, has_input: bool) -> Option { + // incase we have input, we should not check timeout + if !has_input { + if let Some(timeout) = self.timeout { + if timeout > now { + return None; + } + } + } + + while let Ok(out) = self.rtc.poll_output() { + match out { + Output::Timeout(timeout) => { + self.timeout = Some(timeout); + return None; + } + Output::Transmit(send) => { + return Some(WhepOutput::UdpPacket { + to: send.destination, + data: send.contents.to_vec(), + }); + } + Output::Event(e) => match e { + Str0mEvent::Connected => { + log::info!("WhepServerTask connected"); + return WhepOutput::Started(self.room.clone()).into(); + } + Str0mEvent::MediaAdded(media) => { + log::info!("WhepServerTask media added: {:?}", media); + if media.kind == MediaKind::Audio { + self.audio_mid = Some(media.mid); + } else { + self.video_mid = Some(media.mid); + } + } + Str0mEvent::IceConnectionStateChange(state) => match state { + IceConnectionState::Disconnected => { + return WhepOutput::Destroy.into(); + } + _ => {} + }, + Str0mEvent::KeyframeRequest(req) => { + return Some(WhepOutput::RequestKey(req.kind)); + } + _ => {} + }, + } + } + + None + } +} + +impl WhepTask { + /// Called on each tick of the task. + pub fn on_tick<'a>(&mut self, now: Instant) -> Option { + let timeout = self.timeout?; + if now < timeout { + return None; + } + + if let Err(e) = self.rtc.handle_input(Input::Timeout(now)) { + log::error!("Error handling timeout: {}", e); + } + log::trace!("clear timeout after handled"); + self.timeout = None; + self.pop_event_inner(now, true) + } + + /// Called when an input event is received for the task. + pub fn on_event<'a>(&mut self, now: Instant, input: WhepInput<'a>) -> Option { + match input { + WhepInput::UdpPacket { from, data } => { + if let Err(e) = self + .rtc + .handle_input(Input::Receive(now, Receive::new(Protocol::Udp, from, self.backend_addr, &data).expect("Should parse udp"))) + { + log::error!("Error handling udp: {}", e); + } + self.timeout = None; + self.pop_event_inner(now, true) + } + WhepInput::Media(media) => { + let (mid, nackable) = if media.pt == 111 { + (self.audio_mid, false) + } else { + (self.video_mid, true) + }; + + if let Some(mid) = mid { + if let Some(stream) = self.rtc.direct_api().stream_tx_by_mid(mid, None) { + log::debug!("Write rtp for mid: {:?} {} {} {}", mid, media.seq_no, media.time, media.payload.len()); + if let Err(e) = stream.write_rtp( + media.pt.into(), + media.seq_no.into(), + media.time, + Instant::now(), + media.marker, + Default::default(), + nackable, + media.payload.clone(), + ) { + log::error!("Error writing rtp: {}", e); + } + log::trace!("clear timeout with media"); + self.timeout = None; + self.pop_event_inner(now, true) + } else { + None + } + } else { + log::error!("No mid for media {}", media.pt); + None + } + } + } + } + + /// Retrieves the next output event from the task. + pub fn pop_output<'a>(&mut self, now: Instant) -> Option { + self.pop_event_inner(now, false) + } + + pub fn shutdown<'a>(&mut self, _now: Instant) -> Option { + self.rtc.disconnect(); + return WhepOutput::Destroy.into(); + } +} diff --git a/examples/whip-whep/src/sfu/whip.rs b/examples/whip-whep/src/sfu/whip.rs new file mode 100644 index 00000000..6676f462 --- /dev/null +++ b/examples/whip-whep/src/sfu/whip.rs @@ -0,0 +1,173 @@ +use std::{net::SocketAddr, time::Instant}; + +use sans_io_runtime::Buffer; +use str0m::{ + change::{DtlsCert, SdpOffer}, + ice::IceCreds, + media::{KeyframeRequestKind, MediaKind, Mid}, + net::{Protocol, Receive}, + Candidate, Event as Str0mEvent, IceConnectionState, Input, Output, Rtc, +}; + +use super::TrackMedia; + +pub struct WhipTaskBuildResult { + pub task: WhipTask, + pub ice_ufrag: String, + pub sdp: String, +} + +pub enum WhipInput<'a> { + UdpPacket { from: SocketAddr, data: Buffer<'a> }, + KeyFrame(KeyframeRequestKind), +} + +pub enum WhipOutput { + UdpPacket { to: SocketAddr, data: Vec }, + Started(String), + Media(TrackMedia), + Destroy, +} + +pub struct WhipTask { + backend_addr: SocketAddr, + timeout: Option, + rtc: Rtc, + audio_mid: Option, + video_mid: Option, + room: String, +} + +impl WhipTask { + pub fn build(backend_addr: SocketAddr, dtls_cert: DtlsCert, room: String, sdp: &str) -> Result { + let rtc_config = Rtc::builder().set_rtp_mode(true).set_ice_lite(true).set_dtls_cert(dtls_cert).set_local_ice_credentials(IceCreds::new()); + + let ice_ufrag = rtc_config.local_ice_credentials().as_ref().expect("should have ice credentials").ufrag.clone(); + let mut rtc = rtc_config.build(); + rtc.direct_api().enable_twcc_feedback(); + + rtc.add_local_candidate(Candidate::host(backend_addr, Protocol::Udp).expect("Should create candidate")); + + let offer = SdpOffer::from_sdp_string(&sdp).expect("Should parse offer"); + let answer = rtc.sdp_api().accept_offer(offer).expect("Should accept offer"); + let instance = Self { + backend_addr, + timeout: None, + rtc, + audio_mid: None, + video_mid: None, + room, + }; + + Ok(WhipTaskBuildResult { + task: instance, + ice_ufrag, + sdp: answer.to_sdp_string(), + }) + } + + fn pop_event_inner(&mut self, now: Instant, has_input: bool) -> Option { + // incase we have input, we should not check timeout + if !has_input { + if let Some(timeout) = self.timeout { + if timeout > now { + return None; + } + } + } + + while let Ok(out) = self.rtc.poll_output() { + match out { + Output::Timeout(timeout) => { + self.timeout = Some(timeout); + break; + } + Output::Transmit(send) => { + return WhipOutput::UdpPacket { + to: send.destination, + data: send.contents.to_vec(), + } + .into(); + } + Output::Event(e) => match e { + Str0mEvent::Connected => { + log::info!("WhipServerTask connected"); + return WhipOutput::Started(self.room.clone()).into(); + } + Str0mEvent::MediaAdded(media) => { + log::info!("WhipServerTask media added: {:?}", media); + if media.kind == MediaKind::Audio { + self.audio_mid = Some(media.mid); + } else { + self.video_mid = Some(media.mid); + } + } + Str0mEvent::IceConnectionStateChange(state) => match state { + IceConnectionState::Disconnected => { + return WhipOutput::Destroy.into(); + } + _ => {} + }, + Str0mEvent::RtpPacket(rtp) => { + let media = TrackMedia::from_raw(rtp); + return Some(WhipOutput::Media(media)); + } + _ => {} + }, + } + } + + None + } +} + +impl WhipTask { + /// Called on each tick of the task. + pub fn on_tick<'a>(&mut self, now: Instant) -> Option { + let timeout = self.timeout?; + if now < timeout { + return None; + } + + if let Err(e) = self.rtc.handle_input(Input::Timeout(now)) { + log::error!("Error handling timeout: {}", e); + } + self.timeout = None; + self.pop_event_inner(now, true) + } + + /// Called when an input event is received for the task. + pub fn on_event<'a>(&mut self, now: Instant, input: WhipInput<'a>) -> Option { + match input { + WhipInput::UdpPacket { from, data } => { + if let Err(e) = self + .rtc + .handle_input(Input::Receive(now, Receive::new(Protocol::Udp, from, self.backend_addr, &data).expect("Should parse udp"))) + { + log::error!("Error handling udp: {}", e); + } + self.pop_event_inner(now, true) + } + WhipInput::KeyFrame(kind) => { + if let Some(mid) = self.video_mid { + log::info!("Requesting keyframe for video mid: {:?}", mid); + self.rtc.direct_api().stream_rx_by_mid(mid, None).expect("Should has video mid").request_keyframe(kind); + self.pop_event_inner(now, true) + } else { + log::error!("No video mid for requesting keyframe"); + None + } + } + } + } + + /// Retrieves the next output event from the task. + pub fn pop_output<'a>(&mut self, now: Instant) -> Option { + self.pop_event_inner(now, false) + } + + pub fn shutdown<'a>(&mut self, _now: Instant) -> Option { + self.rtc.disconnect(); + WhipOutput::Destroy.into() + } +} diff --git a/examples/whip-whep/src/worker.rs b/examples/whip-whep/src/worker.rs new file mode 100644 index 00000000..5dc58e50 --- /dev/null +++ b/examples/whip-whep/src/worker.rs @@ -0,0 +1,365 @@ +use std::{collections::VecDeque, net::SocketAddr, sync::Arc, time::Instant}; + +use atm0s_sdn::{ + base::{Authorization, HandshakeBuilder, ServiceBuilder}, + convert_enum, + features::{FeaturesControl, FeaturesEvent}, + services::visualization, + ControllerPlaneCfg, DataPlaneCfg, NetInput, NetOutput, NodeId, SdnChannel, SdnEvent, SdnExtIn, SdnExtOut, SdnOwner, SdnWorker, SdnWorkerBusEvent, SdnWorkerCfg, SdnWorkerInput, SdnWorkerOutput, + ShadowRouterHistory, TimePivot, +}; +use rand::rngs::ThreadRng; +use sans_io_runtime::{ + backend::{BackendIncoming, BackendOutgoing}, + BusChannelControl, BusControl, BusEvent, TaskSwitcher, WorkerInner, WorkerInnerInput, WorkerInnerOutput, +}; + +use crate::{ + http::{HttpRequest, HttpResponse}, + sfu::{self, SfuChannel, SfuOwner, SfuWorker}, +}; + +#[repr(usize)] +enum TaskType { + Sfu = 0, + Sdn = 1, +} + +impl TryFrom for TaskType { + type Error = (); + fn try_from(value: usize) -> Result { + match value { + 0 => Ok(Self::Sfu), + 1 => Ok(Self::Sdn), + _ => Err(()), + } + } +} + +#[derive(convert_enum::From, convert_enum::TryInto, Clone, Debug)] +pub enum ExtIn { + Sdn(SdnExtIn), + HttpRequest(HttpRequest), +} + +#[derive(convert_enum::From, convert_enum::TryInto, Clone)] +pub enum ExtOut { + Sdn(SdnExtOut), + HttpResponse(HttpResponse), +} + +#[derive(convert_enum::From, convert_enum::TryInto, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ChannelId { + Sfu(SfuChannel), + Sdn(SdnChannel), +} + +#[derive(convert_enum::From, convert_enum::TryInto, Clone, Debug)] +pub enum Event { + Sdn(SdnEvent), +} + +pub struct ControllerCfg { + pub session: u64, + pub auth: Arc, + pub handshake: Arc, + #[cfg(feature = "vpn")] + pub vpn_tun_device: Option, +} + +pub struct SdnInnerCfg { + pub node_id: NodeId, + pub tick_ms: u64, + pub udp_port: u16, + pub controller: Option, + pub services: Vec>>, + pub history: Arc, + #[cfg(feature = "vpn")] + pub vpn_tun_fd: Option, +} + +pub struct ICfg { + pub sdn: SdnInnerCfg, + pub sdn_listen: SocketAddr, + pub sfu: SocketAddr, +} + +#[derive(convert_enum::From, convert_enum::TryInto)] +pub enum SCfg {} + +pub type SC = visualization::Control; +pub type SE = visualization::Event; +pub type TC = (); +pub type TW = (); + +#[derive(convert_enum::From, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RunnerOwner { + Sdn(SdnOwner), + Sfu(SfuOwner), +} + +pub struct RunnerWorker { + worker: u16, + sdn: SdnWorker, + sfu: SfuWorker, + sdn_backend_slot: usize, + sfu_backend_slot: usize, + switcher: TaskSwitcher, + time: TimePivot, + queue: VecDeque>, +} + +impl WorkerInner for RunnerWorker { + fn build(worker: u16, cfg: ICfg) -> Self { + let mut queue = VecDeque::from([ + WorkerInnerOutput::Net(RunnerOwner::Sdn(SdnOwner), BackendOutgoing::UdpListen { addr: cfg.sdn_listen, reuse: true }), + WorkerInnerOutput::Net(RunnerOwner::Sfu(SfuOwner), BackendOutgoing::UdpListen { addr: cfg.sfu, reuse: false }), + WorkerInnerOutput::Bus(BusControl::Channel( + RunnerOwner::Sdn(SdnOwner), + BusChannelControl::Subscribe(ChannelId::Sdn(SdnChannel::Worker(worker))), + )), + ]); + + if cfg.sdn.controller.is_some() { + queue.push_back(WorkerInnerOutput::Bus(BusControl::Channel( + RunnerOwner::Sdn(SdnOwner), + BusChannelControl::Subscribe(ChannelId::Sdn(SdnChannel::Controller)), + ))); + } + + Self { + worker, + sdn: SdnWorker::new(SdnWorkerCfg { + node_id: cfg.sdn.node_id, + tick_ms: cfg.sdn.tick_ms, + controller: cfg.sdn.controller.map(|c| ControllerPlaneCfg { + session: c.session, + services: cfg.sdn.services.clone(), + authorization: c.auth, + handshake_builder: c.handshake, + random: Box::new(ThreadRng::default()), + }), + data: DataPlaneCfg { + worker_id: 0, + services: cfg.sdn.services.clone(), + history: cfg.sdn.history.clone(), + }, + }), + sfu: SfuWorker::build(worker), + sfu_backend_slot: 0, + sdn_backend_slot: 0, + switcher: TaskSwitcher::new(2), + time: TimePivot::build(), + queue, + } + } + + fn worker_index(&self) -> u16 { + self.worker + } + + fn tasks(&self) -> usize { + self.sdn.tasks() + self.sfu.tasks() + } + + fn spawn(&mut self, now: Instant, _cfg: SCfg) { + unimplemented!() + } + + fn on_tick<'a>(&mut self, now: Instant) -> Option> { + let s = &mut self.switcher; + while let Some(current) = s.looper_current(now) { + match current.try_into().ok()? { + TaskType::Sdn => { + let now_ms = self.time.timestamp_ms(now); + if let Some(out) = s.looper_process(self.sdn.on_tick(now_ms)) { + return self.process_sdn(now, out); + } + } + TaskType::Sfu => { + if let Some(out) = s.looper_process(self.sfu.on_tick(now)) { + return self.process_sfu(now, out); + } + } + } + } + + self.queue.pop_front() + } + + fn on_event<'a>(&mut self, now: Instant, event: WorkerInnerInput<'a, RunnerOwner, ExtIn, ChannelId, Event>) -> Option> { + match event { + WorkerInnerInput::Net(owner, event) => match owner { + RunnerOwner::Sdn(_owner) => { + let now_ms = self.time.timestamp_ms(now); + match event { + BackendIncoming::UdpPacket { slot: _, from, data } => { + let out = self.sdn.on_event(now_ms, SdnWorkerInput::Net(NetInput::UdpPacket(from, data)))?; + self.switcher.queue_flag_task(TaskType::Sdn as usize); + self.process_sdn(now, out) + } + BackendIncoming::UdpListenResult { bind: _, result } => { + log::info!("Sdn listen result: {:?}", result); + self.sdn_backend_slot = result.ok()?.1; + None + } + } + } + RunnerOwner::Sfu(_owner) => match event { + BackendIncoming::UdpPacket { slot, from, data } => { + let out = self.sfu.on_event(now, sfu::Input::UdpPacket { from, data: data.freeze() })?; + self.switcher.queue_flag_task(TaskType::Sfu as usize); + self.process_sfu(now, out) + } + BackendIncoming::UdpListenResult { bind, result } => { + log::info!("Sfu listen result: {:?}", result); + let (addr, slot) = result.ok()?; + self.sfu_backend_slot = slot; + let out = self.sfu.on_event(now, sfu::Input::UdpBind { addr })?; + self.switcher.queue_flag_task(TaskType::Sfu as usize); + self.process_sfu(now, out) + } + }, + }, + WorkerInnerInput::Ext(ext) => match ext { + ExtIn::Sdn(ext) => { + let now_ms = self.time.timestamp_ms(now); + let out = self.sdn.on_event(now_ms, SdnWorkerInput::Ext(ext))?; + self.switcher.queue_flag_task(TaskType::Sdn as usize); + self.process_sdn(now, out) + } + ExtIn::HttpRequest(req) => { + let out = self.sfu.on_event(now, sfu::Input::HttpRequest(req))?; + self.switcher.queue_flag_task(TaskType::Sfu as usize); + self.process_sfu(now, out) + } + }, + WorkerInnerInput::Bus(event) => match event { + BusEvent::Broadcast(_from, Event::Sdn(event)) => { + let now_ms = self.time.timestamp_ms(now); + let out = self.sdn.on_event(now_ms, SdnWorkerInput::Bus(event))?; + self.process_sdn(now, out) + } + BusEvent::Channel(_owner, _channel, Event::Sdn(event)) => { + let now_ms = self.time.timestamp_ms(now); + let out = self.sdn.on_event(now_ms, SdnWorkerInput::Bus(event))?; + self.process_sdn(now, out) + } + }, + } + } + + fn pop_output<'a>(&mut self, now: Instant) -> Option> { + let s = &mut self.switcher; + while let Some(current) = s.queue_current() { + match current.try_into().ok()? { + TaskType::Sdn => { + let now_ms = self.time.timestamp_ms(now); + if let Some(out) = s.queue_process(self.sdn.pop_output(now_ms)) { + return self.process_sdn(now, out); + } + } + TaskType::Sfu => { + if let Some(out) = s.queue_process(self.sfu.pop_output(now)) { + return self.process_sfu(now, out); + } + } + } + } + + self.queue.pop_front() + } + + fn shutdown<'a>(&mut self, now: Instant) -> Option> { + let s = &mut self.switcher; + while let Some(current) = s.looper_current(now) { + match current.try_into().ok()? { + TaskType::Sdn => { + let now_ms = self.time.timestamp_ms(now); + if let Some(out) = s.looper_process(self.sdn.on_event(now_ms, SdnWorkerInput::ShutdownRequest)) { + return self.process_sdn(now, out); + } + } + TaskType::Sfu => { + if let Some(out) = s.looper_process(self.sfu.shutdown(now)) { + return self.process_sfu(now, out); + } + } + } + } + + None + } +} + +impl RunnerWorker { + fn process_sdn<'a>(&mut self, now: Instant, out: SdnWorkerOutput<'a, SC, SE, TC, TW>) -> Option> { + self.switcher.queue_flag_task(TaskType::Sdn as usize); + match out { + SdnWorkerOutput::Ext(ext) => Some(WorkerInnerOutput::Ext(true, ExtOut::Sdn(ext))), + SdnWorkerOutput::ExtWorker(event) => match event { + SdnExtOut::FeaturesEvent(event) => { + if let FeaturesEvent::PubSub(event) = event { + let out = self.sfu.on_event(now, sfu::Input::PubsubEvent(event))?; + self.process_sfu(now, out) + } else { + None + } + } + SdnExtOut::ServicesEvent(service, event) => None, + }, + SdnWorkerOutput::Net(out) => match out { + NetOutput::UdpPacket(remote, data) => Some(WorkerInnerOutput::Net( + RunnerOwner::Sdn(SdnOwner), + BackendOutgoing::UdpPacket { + slot: self.sdn_backend_slot, + to: remote, + data, + }, + )), + NetOutput::UdpPackets(remotes, data) => Some(WorkerInnerOutput::Net( + RunnerOwner::Sdn(SdnOwner), + BackendOutgoing::UdpPackets { + slot: self.sdn_backend_slot, + to: remotes, + data, + }, + )), + }, + SdnWorkerOutput::Bus(event) => match event { + SdnWorkerBusEvent::Control(..) => Some(WorkerInnerOutput::Bus(BusControl::Channel( + RunnerOwner::Sdn(SdnOwner), + BusChannelControl::Publish(ChannelId::Sdn(SdnChannel::Controller), true, event.into()), + ))), + SdnWorkerBusEvent::Workers(..) => Some(WorkerInnerOutput::Bus(BusControl::Broadcast(true, event.into()))), + SdnWorkerBusEvent::Worker(worker, msg) => Some(WorkerInnerOutput::Bus(BusControl::Channel( + RunnerOwner::Sdn(SdnOwner), + BusChannelControl::Publish(ChannelId::Sdn(SdnChannel::Worker(worker)), true, Event::Sdn(SdnEvent::Worker(self.worker, msg))), + ))), + }, + SdnWorkerOutput::ShutdownResponse => None, + SdnWorkerOutput::Continue => None, + } + } + + fn process_sfu<'a>(&mut self, now: Instant, out: sfu::Output) -> Option> { + self.switcher.queue_flag_task(TaskType::Sfu as usize); + match out { + sfu::Output::HttpResponse(res) => Some(WorkerInnerOutput::Ext(true, ExtOut::HttpResponse(res))), + sfu::Output::PubsubControl(control) => { + let now_ms = self.time.timestamp_ms(now); + let out = self.sdn.on_event(now_ms, SdnWorkerInput::ExtWorker(SdnExtIn::FeaturesControl(FeaturesControl::PubSub(control))))?; + self.process_sdn(now, out) + } + sfu::Output::UdpPacket { to, data } => Some(WorkerInnerOutput::Net( + RunnerOwner::Sdn(SdnOwner), + BackendOutgoing::UdpPacket { + slot: self.sfu_backend_slot, + to, + data: data.into(), + }, + )), + sfu::Output::Continue => None, + } + } +} diff --git a/packages/network/Cargo.toml b/packages/network/Cargo.toml index 1615b213..b8d15e67 100644 --- a/packages/network/Cargo.toml +++ b/packages/network/Cargo.toml @@ -32,6 +32,6 @@ env_logger = { workspace = true } atm0s-sdn-router = { path = "../core/router", version = "0.1.4" } [features] -default = ["vpn", "fuzz"] +default = ["fuzz"] vpn = [] fuzz = [] diff --git a/packages/network/src/base/buf.rs b/packages/network/src/base/buf.rs deleted file mode 100644 index e0266822..00000000 --- a/packages/network/src/base/buf.rs +++ /dev/null @@ -1,409 +0,0 @@ -use std::ops::{Deref, DerefMut}; - -#[derive(Debug, Clone)] -pub enum ReadOnlyBuffer<'a> { - Ref(&'a [u8]), - Vec(Vec), -} - -impl<'a> ReadOnlyBuffer<'a> { - pub fn view<'b>(&'b self) -> ReadOnlyBuffer<'b> { - match self { - Self::Ref(r) => ReadOnlyBuffer::Ref(r), - Self::Vec(v) => ReadOnlyBuffer::Ref(v), - } - } - - pub fn owned(self) -> ReadOnlyBuffer<'static> { - match self { - Self::Ref(r) => ReadOnlyBuffer::Vec(r.to_vec()), - Self::Vec(v) => ReadOnlyBuffer::Vec(v), - } - } - - pub fn owned_mut(self) -> WriteableBuffer<'static> { - match self { - Self::Ref(r) => WriteableBuffer::Vec(r.to_vec()), - Self::Vec(v) => WriteableBuffer::Vec(v), - } - } - - pub fn clone_mut(&self) -> WriteableBuffer<'static> { - match self { - Self::Ref(r) => WriteableBuffer::Vec(r.to_vec()), - Self::Vec(v) => WriteableBuffer::Vec(v.to_vec()), - } - } -} - -impl<'a> Deref for ReadOnlyBuffer<'a> { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - match self { - Self::Ref(r) => r, - Self::Vec(v) => v, - } - } -} - -impl<'a> From<&'a [u8]> for ReadOnlyBuffer<'a> { - fn from(value: &'a [u8]) -> Self { - Self::Ref(value) - } -} - -impl<'a> From> for ReadOnlyBuffer<'a> { - fn from(value: Vec) -> Self { - Self::Vec(value) - } -} - -#[derive(Debug)] -pub enum WriteableBuffer<'a> { - Ref(&'a mut [u8]), - Vec(Vec), -} - -impl<'a> WriteableBuffer<'a> { - pub fn owned(self) -> WriteableBuffer<'static> { - match self { - Self::Ref(r) => WriteableBuffer::Vec(r.to_vec()), - Self::Vec(v) => WriteableBuffer::Vec(v), - } - } - - pub fn freeze(self) -> ReadOnlyBuffer<'a> { - match self { - Self::Ref(r) => ReadOnlyBuffer::Ref(r), - Self::Vec(v) => ReadOnlyBuffer::Vec(v), - } - } - - pub fn copy_readonly(&self) -> ReadOnlyBuffer<'static> { - match self { - Self::Ref(r) => ReadOnlyBuffer::Vec(r.to_vec()), - Self::Vec(v) => ReadOnlyBuffer::Vec(v.clone()), - } - } - - pub fn view<'b>(&'b self) -> ReadOnlyBuffer<'b> { - match self { - Self::Ref(r) => ReadOnlyBuffer::Ref(r), - Self::Vec(v) => ReadOnlyBuffer::Ref(v), - } - } - - pub fn ensure_back(&mut self, more: usize) { - match self { - Self::Ref(r) => { - let mut v = Vec::with_capacity(r.len() + more); - v[0..r.len()].copy_from_slice(r); - unsafe { - v.set_len(r.len() + more); - } - } - Self::Vec(v) => { - v.reserve_exact(more); - unsafe { - v.set_len(v.capacity()); - } - } - } - } -} - -impl<'a> Deref for WriteableBuffer<'a> { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - match self { - Self::Ref(r) => r, - Self::Vec(v) => v, - } - } -} - -impl<'a> DerefMut for WriteableBuffer<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - Self::Ref(r) => r, - Self::Vec(v) => v, - } - } -} - -impl<'a> From<&'a mut [u8]> for WriteableBuffer<'a> { - fn from(value: &'a mut [u8]) -> Self { - Self::Ref(value) - } -} - -impl<'a> From> for WriteableBuffer<'a> { - fn from(value: Vec) -> Self { - Self::Vec(value) - } -} - -#[derive(Debug, Clone)] -pub struct GenericBuffer<'a> { - pub buf: ReadOnlyBuffer<'a>, - pub range: std::ops::Range, -} - -impl<'a> GenericBuffer<'a> { - pub fn owned(self) -> GenericBuffer<'static> { - GenericBuffer { - buf: self.buf.owned(), - range: self.range, - } - } - - pub fn owned_mut(self) -> GenericBufferMut<'static> { - GenericBufferMut { - buf: self.buf.owned_mut(), - range: self.range, - } - } - - pub fn clone_mut(&self) -> GenericBufferMut<'static> { - GenericBufferMut { - buf: self.buf.clone_mut(), - range: self.range.clone(), - } - } - - pub fn view<'b>(&'b self, range: std::ops::Range) -> Option> { - if self.range.end - self.range.start >= range.end { - Some(GenericBuffer { - buf: self.buf.view(), - range: (self.range.start + range.start..self.range.start + range.end), - }) - } else { - None - } - } - - pub fn pop_back<'b>(&'b mut self, len: usize) -> Option> { - if self.range.end - self.range.start >= len { - self.range.end -= len; - Some(GenericBuffer { - buf: self.buf.view(), - range: (self.range.end..self.range.end + len), - }) - } else { - None - } - } - - pub fn pop_front<'b>(&'b mut self, len: usize) -> Option> { - if self.range.end - self.range.start >= len { - self.range.start += len; - Some(GenericBuffer { - buf: self.buf.view(), - range: (self.range.start - len..self.range.start), - }) - } else { - None - } - } - - pub fn to_slice2(&self) -> &'_ [u8] { - &self.buf.deref()[self.range.clone()] - } -} - -impl<'a> Deref for GenericBuffer<'a> { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.buf.deref()[self.range.clone()] - } -} - -impl<'a> From<&'a [u8]> for GenericBuffer<'a> { - fn from(value: &'a [u8]) -> Self { - GenericBuffer { - buf: value.into(), - range: (0..value.len()), - } - } -} - -impl From> for GenericBuffer<'_> { - fn from(value: Vec) -> Self { - let len = value.len(); - GenericBuffer { buf: value.into(), range: (0..len) } - } -} - -#[derive(Debug)] -pub struct GenericBufferMut<'a> { - buf: WriteableBuffer<'a>, - range: std::ops::Range, -} - -impl<'a> GenericBufferMut<'a> { - /// Create a buffer mut with append some bytes at first and some bytes at end - pub fn build(data: &[u8], more_left: usize, more_right: usize) -> GenericBufferMut<'static> { - let mut v = Vec::with_capacity(more_left + data.len() + more_right); - unsafe { - v.set_len(v.capacity()); - }; - v[more_left..more_left + data.len()].copy_from_slice(data); - GenericBufferMut { - buf: v.into(), - range: (more_left..more_left + data.len()), - } - } - - /// Create a new buffer from a slice, we can manually set the start and end of the buffer. - pub fn from_slice_raw(data: &'a mut [u8], range: std::ops::Range) -> GenericBufferMut<'a> { - assert!(range.end <= data.len()); - GenericBufferMut { buf: data.into(), range } - } - - /// Create a new buffer from a vec, we can manually set the start and end of the buffer. - pub fn from_vec_raw(data: Vec, range: std::ops::Range) -> GenericBufferMut<'a> { - assert!(range.end <= data.len()); - GenericBufferMut { buf: data.into(), range } - } - - pub fn freeze(self) -> GenericBuffer<'a> { - GenericBuffer { - buf: self.buf.freeze(), - range: self.range, - } - } - - pub fn copy_readonly(&self) -> GenericBuffer<'static> { - GenericBuffer { - buf: self.buf.copy_readonly(), - range: self.range.clone(), - } - } - - /// Reverse the buffer for at least `len` bytes at back. - pub fn ensure_back(&mut self, more: usize) { - assert!(self.buf.len() >= self.range.end); - let remain = self.buf.len() - self.range.end; - if remain < more { - self.buf.ensure_back(more - remain); - } - } - - pub fn truncate(&mut self, len: usize) -> Option<()> { - if self.range.end - self.range.start >= len { - self.range.end = self.range.start + len; - Some(()) - } else { - None - } - } - - pub fn push_back(&mut self, data: &[u8]) { - self.ensure_back(data.len()); - self.buf.deref_mut()[self.range.end..self.range.end + data.len()].copy_from_slice(data); - self.range.end += data.len(); - } - - pub fn pop_back(&mut self, len: usize) -> Option> { - if self.range.end - self.range.start >= len { - self.range.end -= len; - Some(GenericBuffer { - buf: self.buf.view(), - range: (self.range.end..self.range.end + len), - }) - } else { - None - } - } - - pub fn pop_front(&mut self, len: usize) -> Option> { - if self.range.end - self.range.start >= len { - self.range.start += len; - Some(GenericBuffer { - buf: self.buf.view(), - range: (self.range.start - len..self.range.start), - }) - } else { - None - } - } - - pub fn move_front_right(&mut self, len: usize) -> Option<()> { - if self.range.start + len > self.range.end { - return None; - } - self.range.start += len; - Some(()) - } - - pub fn move_front_left(&mut self, len: usize) -> Option<()> { - if self.range.start < len { - return None; - } - self.range.start -= len; - Some(()) - } -} - -impl<'a> From<&'a mut [u8]> for GenericBufferMut<'a> { - fn from(value: &'a mut [u8]) -> Self { - GenericBufferMut { - range: (0..value.len()), - buf: value.into(), - } - } -} - -impl From> for GenericBufferMut<'_> { - fn from(value: Vec) -> Self { - GenericBufferMut { - range: (0..value.len()), - buf: value.into(), - } - } -} - -impl<'a> Deref for GenericBufferMut<'_> { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.buf.deref()[self.range.start..self.range.end] - } -} - -impl<'a> DerefMut for GenericBufferMut<'_> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.buf.deref_mut()[self.range.start..self.range.end] - } -} - -#[cfg(test)] -mod tests { - use std::ops::Deref; - - use super::{GenericBuffer, GenericBufferMut}; - - #[test] - fn simple_buffer_view() { - let data = vec![1, 2, 3, 4, 5, 6]; - let mut buf: GenericBuffer = (data.as_slice()).into(); - assert_eq!(buf.len(), 6); - assert_eq!(buf.pop_back(2).expect("").deref(), &[5, 6]); - assert_eq!(buf.pop_front(2).expect("").deref(), &[1, 2]); - assert_eq!(buf.deref(), &[3, 4]); - } - - #[test] - fn simple_buffer_mut() { - let mut buf = GenericBufferMut::build(&[1, 2, 3, 4, 5, 6], 4, 4); - assert_eq!(buf.deref(), &[1, 2, 3, 4, 5, 6]); - assert_eq!(buf.to_vec(), vec![1, 2, 3, 4, 5, 6]); - println!("{:?}", buf); - let res = buf.pop_back(2).expect(""); - println!("{:?}", res); - assert_eq!(res.deref(), &[5, 6]); - assert_eq!(res.to_vec(), &[5, 6]); - assert_eq!(buf.pop_front(2).expect("").deref(), &[1, 2]); - } -} diff --git a/packages/network/src/base/feature.rs b/packages/network/src/base/feature.rs index db95d58b..36bf0429 100644 --- a/packages/network/src/base/feature.rs +++ b/packages/network/src/base/feature.rs @@ -3,7 +3,9 @@ use std::net::SocketAddr; use atm0s_sdn_identity::{ConnId, NodeAddr, NodeId}; use atm0s_sdn_router::{shadow::ShadowRouter, RouteRule}; -use super::{ConnectionCtx, ConnectionEvent, GenericBuffer, GenericBufferMut, ServiceId, TransportMsgHeader, Ttl}; +#[cfg(feature = "vpn")] +use super::BufferMut; +use super::{Buffer, ConnectionCtx, ConnectionEvent, ServiceId, TransportMsgHeader, Ttl}; /// /// @@ -46,6 +48,15 @@ impl NetOutgoingMeta { Self { source, ttl, meta, secure } } + pub fn secure() -> Self { + Self { + source: false, + ttl: Ttl::default(), + meta: 0, + secure: true, + } + } + pub fn to_header(&self, feature: u8, rule: RouteRule, node_id: NodeId) -> TransportMsgHeader { TransportMsgHeader::build(feature, self.meta, rule) .set_ttl(*self.ttl) @@ -74,6 +85,7 @@ impl NetOutgoingMeta { #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] pub enum FeatureControlActor { Controller, + Worker(u16), Service(ServiceId), } @@ -137,9 +149,10 @@ pub enum FeatureWorkerInput<'a, Control, ToWorker> { /// First bool is flag for broadcast or not FromController(bool, ToWorker), Control(FeatureControlActor, Control), - Network(ConnId, NetIncomingMeta, GenericBuffer<'a>), - Local(NetIncomingMeta, GenericBuffer<'a>), - TunPkt(GenericBufferMut<'a>), + Network(ConnId, NetIncomingMeta, Buffer<'a>), + Local(NetIncomingMeta, Buffer<'a>), + #[cfg(feature = "vpn")] + TunPkt(BufferMut<'a>), } #[derive(Clone)] @@ -151,11 +164,12 @@ pub enum FeatureWorkerOutput<'a, Control, Event, ToController> { Event(FeatureControlActor, Event), SendDirect(ConnId, NetOutgoingMeta, Vec), SendRoute(RouteRule, NetOutgoingMeta, Vec), - RawDirect(ConnId, GenericBuffer<'a>), - RawBroadcast(Vec, GenericBuffer<'a>), - RawDirect2(SocketAddr, GenericBuffer<'a>), - RawBroadcast2(Vec, GenericBuffer<'a>), - TunPkt(GenericBuffer<'a>), + RawDirect(ConnId, Buffer<'a>), + RawBroadcast(Vec, Buffer<'a>), + RawDirect2(SocketAddr, Buffer<'a>), + RawBroadcast2(Vec, Buffer<'a>), + #[cfg(feature = "vpn")] + TunPkt(Buffer<'a>), } impl<'a, Control, Event, ToController> FeatureWorkerOutput<'a, Control, Event, ToController> { @@ -177,6 +191,7 @@ impl<'a, Control, Event, ToController> FeatureWorkerOutput<'a, Control, Event, T FeatureWorkerOutput::RawBroadcast(conns, buf) => FeatureWorkerOutput::RawBroadcast(conns, buf), FeatureWorkerOutput::RawDirect2(conn, buf) => FeatureWorkerOutput::RawDirect2(conn, buf), FeatureWorkerOutput::RawBroadcast2(conns, buf) => FeatureWorkerOutput::RawBroadcast2(conns, buf), + #[cfg(feature = "vpn")] FeatureWorkerOutput::TunPkt(buf) => FeatureWorkerOutput::TunPkt(buf), } } @@ -194,6 +209,7 @@ impl<'a, Control, Event, ToController> FeatureWorkerOutput<'a, Control, Event, T FeatureWorkerOutput::RawBroadcast(conns, buf) => FeatureWorkerOutput::RawBroadcast(conns, buf.owned()), FeatureWorkerOutput::RawDirect2(conn, buf) => FeatureWorkerOutput::RawDirect2(conn, buf.owned()), FeatureWorkerOutput::RawBroadcast2(conns, buf) => FeatureWorkerOutput::RawBroadcast2(conns, buf.owned()), + #[cfg(feature = "vpn")] FeatureWorkerOutput::TunPkt(buf) => FeatureWorkerOutput::TunPkt(buf.owned()), } } @@ -213,7 +229,7 @@ pub trait FeatureWorker { conn: ConnId, _remote: SocketAddr, header: TransportMsgHeader, - mut buf: GenericBuffer<'a>, + mut buf: Buffer<'a>, ) -> Option> { let header_len = header.serialize_size(); buf.pop_front(header_len).expect("Buffer should bigger or equal header"); @@ -223,6 +239,7 @@ pub trait FeatureWorker { match input { FeatureWorkerInput::Control(actor, control) => Some(FeatureWorkerOutput::ForwardControlToController(actor, control)), FeatureWorkerInput::Network(conn, header, buf) => Some(FeatureWorkerOutput::ForwardNetworkToController(conn, header, buf.to_vec())), + #[cfg(feature = "vpn")] FeatureWorkerInput::TunPkt(_buf) => None, FeatureWorkerInput::FromController(_, _) => { log::warn!("No handler for FromController"); diff --git a/packages/network/src/base/mod.rs b/packages/network/src/base/mod.rs index e5fa9b50..76828d3e 100644 --- a/packages/network/src/base/mod.rs +++ b/packages/network/src/base/mod.rs @@ -1,4 +1,3 @@ -mod buf; mod control; mod feature; mod msg; @@ -8,10 +7,10 @@ mod service; use std::net::SocketAddr; use atm0s_sdn_identity::{ConnId, NodeId}; -pub use buf::*; pub use control::*; pub use feature::*; pub use msg::*; +pub use sans_io_runtime::{Buffer, BufferMut}; pub use secure::*; pub use service::*; diff --git a/packages/network/src/base/secure.rs b/packages/network/src/base/secure.rs index 81e1c895..4da5de92 100644 --- a/packages/network/src/base/secure.rs +++ b/packages/network/src/base/secure.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; use atm0s_sdn_identity::NodeId; -use super::GenericBufferMut; +use super::BufferMut; #[derive(Debug, Clone)] pub struct SecureContext { @@ -46,7 +46,7 @@ pub enum EncryptionError { #[mockall::automock] pub trait Encryptor: Debug + Send + Sync { - fn encrypt<'a>(&mut self, now_ms: u64, data: &mut GenericBufferMut<'a>) -> Result<(), EncryptionError>; + fn encrypt<'a>(&mut self, now_ms: u64, data: &mut BufferMut<'a>) -> Result<(), EncryptionError>; fn clone_box(&self) -> Box; } @@ -65,7 +65,7 @@ pub enum DecryptionError { #[mockall::automock] pub trait Decryptor: Debug + Send + Sync { - fn decrypt<'a>(&mut self, now_ms: u64, data: &mut GenericBufferMut<'a>) -> Result<(), DecryptionError>; + fn decrypt<'a>(&mut self, now_ms: u64, data: &mut BufferMut<'a>) -> Result<(), DecryptionError>; fn clone_box(&self) -> Box; } diff --git a/packages/network/src/base/service.rs b/packages/network/src/base/service.rs index 706a90c5..583af2e5 100644 --- a/packages/network/src/base/service.rs +++ b/packages/network/src/base/service.rs @@ -10,6 +10,7 @@ simple_pub_type!(ServiceId, u8); #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] pub enum ServiceControlActor { Controller, + Worker(u16), } #[derive(Debug, Clone)] @@ -48,12 +49,14 @@ pub trait Service { +pub enum ServiceWorkerInput { + Control(ServiceControlActor, ServiceControl), FromController(ToWorker), FeatureEvent(FeaturesEvent), } -pub enum ServiceWorkerOutput { +pub enum ServiceWorkerOutput { + ForwardControlToController(ServiceControlActor, ServiceControl), ForwardFeatureEventToController(FeaturesEvent), ToController(ToController), FeatureControl(FeaturesControl), @@ -64,7 +67,7 @@ pub struct ServiceWorkerCtx { pub node_id: NodeId, } -pub trait ServiceWorker { +pub trait ServiceWorker { fn service_id(&self) -> u8; fn service_name(&self) -> &str; fn on_tick(&mut self, _ctx: &ServiceWorkerCtx, _now: u64, _tick_count: u64) {} @@ -72,9 +75,10 @@ pub trait ServiceWorker, - ) -> Option> { + input: ServiceWorkerInput, + ) -> Option> { match input { + ServiceWorkerInput::Control(actor, control) => Some(ServiceWorkerOutput::ForwardControlToController(actor, control)), ServiceWorkerInput::FeatureEvent(event) => Some(ServiceWorkerOutput::ForwardFeatureEventToController(event)), ServiceWorkerInput::FromController(_) => { log::warn!("No handler for FromController in {}", self.service_name()); @@ -82,7 +86,7 @@ pub trait ServiceWorker Option> { + fn pop_output(&mut self, _ctx: &ServiceWorkerCtx) -> Option> { None } } @@ -94,5 +98,5 @@ pub trait ServiceBuilder Box>; - fn create_worker(&self) -> Box>; + fn create_worker(&self) -> Box>; } diff --git a/packages/network/src/controller_plane.rs b/packages/network/src/controller_plane.rs index 240cdc38..d0cedb77 100644 --- a/packages/network/src/controller_plane.rs +++ b/packages/network/src/controller_plane.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::VecDeque, sync::Arc}; use atm0s_sdn_identity::NodeId; use rand::RngCore; @@ -20,9 +20,9 @@ mod neighbours; mod services; #[derive(Debug, Clone, convert_enum::From)] -pub enum Input { +pub enum Input { Ext(ExtIn), - Control(LogicControl), + Control(LogicControl), #[convert_enum(optout)] ShutdownRequest, } @@ -30,7 +30,7 @@ pub enum Input { #[derive(Debug, Clone, convert_enum::From)] pub enum Output { Ext(ExtOut), - Event(LogicEvent), + Event(LogicEvent), #[convert_enum(optout)] ShutdownSuccess, } @@ -59,6 +59,14 @@ impl TryFrom for TaskType { } } +pub struct ControllerPlaneCfg { + pub session: u64, + pub services: Vec>>, + pub authorization: Arc, + pub handshake_builder: Arc, + pub random: Box, +} + pub struct ControllerPlane { tick_count: u64, neighbours: NeighboursManager, @@ -67,6 +75,7 @@ pub struct ControllerPlane { service_ctx: ServiceCtx, services: ServiceManager, switcher: TaskSwitcher, + queue: VecDeque>, } impl ControllerPlane { @@ -80,29 +89,24 @@ impl ControllerPlane { /// # Returns /// /// A new ControllerPlane - pub fn new( - node_id: NodeId, - session: u64, - services: Vec>>, - authorization: Arc, - handshake_builder: Arc, - random: Box, - ) -> Self { - log::info!("Create ControllerPlane for node: {}, running session {}", node_id, session); - let service_ids = services.iter().filter(|s| s.discoverable()).map(|s| s.service_id()).collect(); + pub fn new(node_id: NodeId, cfg: ControllerPlaneCfg) -> Self { + log::info!("Create ControllerPlane for node: {}, running session {}", node_id, cfg.session); + let service_ids = cfg.services.iter().filter(|s| s.discoverable()).map(|s| s.service_id()).collect(); Self { tick_count: 0, - neighbours: NeighboursManager::new(node_id, authorization, handshake_builder, random), - feature_ctx: FeatureContext { node_id, session }, - features: FeatureManager::new(node_id, session, service_ids), - service_ctx: ServiceCtx { node_id, session }, - services: ServiceManager::new(services), + neighbours: NeighboursManager::new(node_id, cfg.authorization, cfg.handshake_builder, cfg.random), + feature_ctx: FeatureContext { node_id, session: cfg.session }, + features: FeatureManager::new(node_id, cfg.session, service_ids), + service_ctx: ServiceCtx { node_id, session: cfg.session }, + services: ServiceManager::new(cfg.services), switcher: TaskSwitcher::new(3), //3 types: Neighbours, Feature, Service + queue: VecDeque::new(), } } pub fn on_tick(&mut self, now_ms: u64) { + log::trace!("[ControllerPlane] on_tick: {}", now_ms); self.switcher.queue_flag_all(); self.neighbours.on_tick(now_ms, self.tick_count); self.features.on_shared_input(&self.feature_ctx, now_ms, FeatureSharedInput::Tick(self.tick_count)); @@ -110,7 +114,7 @@ impl ControllerPlane { self.tick_count += 1; } - pub fn on_event(&mut self, now_ms: u64, event: Input) { + pub fn on_event(&mut self, now_ms: u64, event: Input) { match event { Input::Ext(ExtIn::ConnectTo(addr)) => { self.switcher.queue_flag_task(TaskType::Neighbours as usize); @@ -156,10 +160,20 @@ impl ControllerPlane { self.switcher.queue_flag_task(TaskType::Service as usize); self.services.on_input(&self.service_ctx, now_ms, service, ServiceInput::FeatureEvent(event)); } + Input::Control(LogicControl::ServicesControl(actor, service, control)) => { + self.switcher.queue_flag_task(TaskType::Service as usize); + self.services.on_input(&self.service_ctx, now_ms, service, ServiceInput::Control(actor, control)); + } Input::Control(LogicControl::FeaturesControl(actor, control)) => { self.switcher.queue_flag_task(TaskType::Feature as usize); self.features.on_input(&self.feature_ctx, now_ms, control.to_feature(), FeatureInput::Control(actor, control)); } + Input::Control(LogicControl::ExtFeaturesEvent(event)) => { + self.queue.push_back(Output::Ext(ExtOut::FeaturesEvent(event))); + } + Input::Control(LogicControl::ExtServicesEvent(service, event)) => { + self.queue.push_back(Output::Ext(ExtOut::ServicesEvent(service, event))); + } Input::ShutdownRequest => { self.switcher.queue_flag_task(TaskType::Neighbours as usize); self.neighbours.on_input(now_ms, neighbours::Input::ShutdownRequest); @@ -168,6 +182,10 @@ impl ControllerPlane { } pub fn pop_output(&mut self, now_ms: u64) -> Option> { + while let Some(out) = self.queue.pop_front() { + return Some(out); + } + while let Some(current) = self.switcher.queue_current() { match (current as u8).try_into().expect("Should convert to TaskType") { TaskType::Neighbours => { @@ -195,21 +213,26 @@ impl ControllerPlane { } fn pop_neighbours(&mut self, now_ms: u64) -> Option> { - let out = self.neighbours.pop_output(now_ms)?; - match out { - neighbours::Output::Control(remote, control) => Some(Output::Event(LogicEvent::NetNeighbour(remote, control))), - neighbours::Output::Event(event) => { - self.switcher.queue_flag_task(TaskType::Feature as usize); - self.features.on_shared_input(&self.feature_ctx, now_ms, FeatureSharedInput::Connection(event.clone())); - self.switcher.queue_flag_task(TaskType::Service as usize); - self.services.on_shared_input(&self.service_ctx, now_ms, ServiceSharedInput::Connection(event.clone())); - match event { - ConnectionEvent::Connected(ctx, secure) => Some(Output::Event(LogicEvent::Pin(ctx.conn, ctx.node, ctx.remote, secure))), - ConnectionEvent::Stats(_ctx, _stats) => None, - ConnectionEvent::Disconnected(ctx) => Some(Output::Event(LogicEvent::UnPin(ctx.conn))), + loop { + let out = self.neighbours.pop_output(now_ms)?; + let out = match out { + neighbours::Output::Control(remote, control) => Some(Output::Event(LogicEvent::NetNeighbour(remote, control))), + neighbours::Output::Event(event) => { + self.switcher.queue_flag_task(TaskType::Feature as usize); + self.features.on_shared_input(&self.feature_ctx, now_ms, FeatureSharedInput::Connection(event.clone())); + self.switcher.queue_flag_task(TaskType::Service as usize); + self.services.on_shared_input(&self.service_ctx, now_ms, ServiceSharedInput::Connection(event.clone())); + match event { + ConnectionEvent::Connected(ctx, secure) => Some(Output::Event(LogicEvent::Pin(ctx.conn, ctx.node, ctx.remote, secure))), + ConnectionEvent::Stats(_ctx, _stats) => None, + ConnectionEvent::Disconnected(ctx) => Some(Output::Event(LogicEvent::UnPin(ctx.conn))), + } } + neighbours::Output::ShutdownResponse => Some(Output::ShutdownSuccess), + }; + if out.is_some() { + return out; } - neighbours::Output::ShutdownResponse => Some(Output::ShutdownSuccess), } } @@ -222,6 +245,7 @@ impl ControllerPlane { log::debug!("[Controller] send FeatureEvent to actor {:?}, event {:?}", actor, event); match actor { FeatureControlActor::Controller => Some(Output::Ext(ExtOut::FeaturesEvent(event))), + FeatureControlActor::Worker(worker) => Some(Output::Event(LogicEvent::ExtFeaturesEvent(worker, event))), FeatureControlActor::Service(service) => { self.switcher.queue_flag_task(TaskType::Service as usize); self.services.on_input(&self.service_ctx, now_ms, service, ServiceInput::FeatureEvent(event)); @@ -260,7 +284,8 @@ impl ControllerPlane { self.pop_features(now_ms) } ServiceOutput::Event(actor, event) => match actor { - ServiceControlActor::Controller => Some(Output::Ext(ExtOut::ServicesEvent(event))), + ServiceControlActor::Controller => Some(Output::Ext(ExtOut::ServicesEvent(service, event))), + ServiceControlActor::Worker(worker) => Some(Output::Event(LogicEvent::ExtServicesEvent(worker, service, event))), }, ServiceOutput::BroadcastWorkers(to) => Some(Output::Event(LogicEvent::Service(service, to))), } diff --git a/packages/network/src/controller_plane/neighbours.rs b/packages/network/src/controller_plane/neighbours.rs index e1a1da16..648b0a7d 100644 --- a/packages/network/src/controller_plane/neighbours.rs +++ b/packages/network/src/controller_plane/neighbours.rs @@ -91,7 +91,7 @@ impl NeighboursManager { } }; - log::debug!("[NeighboursManager] received Control(addr: {:?}, control: {:?})", addr, control); + log::debug!("[NeighboursManager] received Control(addr: {:?}, cmd: {:?})", addr, cmd); if let Some(conn) = self.connections.get_mut(&addr) { conn.on_input(now_ms, control.from, cmd); } else { diff --git a/packages/network/src/data_plane.rs b/packages/network/src/data_plane.rs index 73843d78..d78a2984 100644 --- a/packages/network/src/data_plane.rs +++ b/packages/network/src/data_plane.rs @@ -13,11 +13,11 @@ use sans_io_runtime::TaskSwitcher; use crate::{ base::{ - FeatureControlActor, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, GenericBuffer, GenericBufferMut, NeighboursControl, NetOutgoingMeta, ServiceBuilder, ServiceControlActor, - ServiceId, ServiceWorkerCtx, ServiceWorkerInput, ServiceWorkerOutput, TransportMsg, TransportMsgHeader, + Buffer, BufferMut, FeatureControlActor, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, NeighboursControl, NetOutgoingMeta, ServiceBuilder, ServiceControlActor, ServiceId, + ServiceWorkerCtx, ServiceWorkerInput, ServiceWorkerOutput, TransportMsg, TransportMsgHeader, }, features::{Features, FeaturesControl, FeaturesEvent, FeaturesToController}, - ExtOut, LogicControl, LogicEvent, + ExtIn, ExtOut, LogicControl, LogicEvent, }; use self::{connection::DataPlaneConnection, features::FeatureWorkerManager, services::ServiceWorkerManager}; @@ -28,29 +28,41 @@ mod services; #[derive(Debug)] pub enum NetInput<'a> { - UdpPacket(SocketAddr, GenericBufferMut<'a>), - TunPacket(GenericBufferMut<'a>), + UdpPacket(SocketAddr, BufferMut<'a>), + #[cfg(feature = "vpn")] + TunPacket(BufferMut<'a>), +} + +#[derive(Debug, Clone)] +pub enum CrossWorker { + Feature(FeaturesEvent), + Service(ServiceId, SE), } #[derive(Debug)] -pub enum Input<'a, TW> { +pub enum Input<'a, SC, SE, TW> { + Ext(ExtIn), Net(NetInput<'a>), - Event(LogicEvent), + Event(LogicEvent), + Worker(CrossWorker), ShutdownRequest, } #[derive(Debug)] pub enum NetOutput<'a> { - UdpPacket(SocketAddr, GenericBuffer<'a>), - UdpPackets(Vec, GenericBuffer<'a>), - TunPacket(GenericBuffer<'a>), + UdpPacket(SocketAddr, Buffer<'a>), + UdpPackets(Vec, Buffer<'a>), + #[cfg(feature = "vpn")] + TunPacket(Buffer<'a>), } #[derive(convert_enum::From)] -pub enum Output<'a, SE, TC> { +pub enum Output<'a, SC, SE, TC> { Ext(ExtOut), Net(NetOutput<'a>), - Control(LogicControl), + Control(LogicControl), + #[convert_enum(optout)] + Worker(u16, CrossWorker), #[convert_enum(optout)] ShutdownResponse, #[convert_enum(optout)] @@ -63,37 +75,45 @@ enum TaskType { Service, } -enum QueueOutput { +enum QueueOutput { Feature(Features, FeatureWorkerOutput<'static, FeaturesControl, FeaturesEvent, FeaturesToController>), - Service(ServiceId, ServiceWorkerOutput), + Service(ServiceId, ServiceWorkerOutput), Net(NetOutput<'static>), } +pub struct DataPlaneCfg { + pub worker_id: u16, + pub services: Vec>>, + pub history: Arc, +} + pub struct DataPlane { tick_count: u64, + worker_id: u16, feature_ctx: FeatureWorkerContext, features: FeatureWorkerManager, service_ctx: ServiceWorkerCtx, services: ServiceWorkerManager, conns: HashMap, conns_reverse: HashMap, - queue_output: VecDeque>, + queue_output: VecDeque>, switcher: TaskSwitcher, } impl DataPlane { - pub fn new(node_id: NodeId, services: Vec>>, history: Arc) -> Self { + pub fn new(node_id: NodeId, cfg: DataPlaneCfg) -> Self { log::info!("Create DataPlane for node: {}", node_id); Self { + worker_id: cfg.worker_id, tick_count: 0, feature_ctx: FeatureWorkerContext { node_id, - router: ShadowRouter::new(node_id, history), + router: ShadowRouter::new(node_id, cfg.history), }, features: FeatureWorkerManager::new(), service_ctx: ServiceWorkerCtx { node_id }, - services: ServiceWorkerManager::new(services), + services: ServiceWorkerManager::new(cfg.services), conns: HashMap::new(), conns_reverse: HashMap::new(), queue_output: VecDeque::new(), @@ -106,14 +126,36 @@ impl DataPlane { } pub fn on_tick<'a>(&mut self, now_ms: u64) { + log::trace!("[DataPlane] on_tick: {}", now_ms); self.switcher.queue_flag_all(); self.features.on_tick(&mut self.feature_ctx, now_ms, self.tick_count); self.services.on_tick(&mut self.service_ctx, now_ms, self.tick_count); self.tick_count += 1; } - pub fn on_event<'a>(&mut self, now_ms: u64, event: Input<'a, TW>) -> Option> { + pub fn on_event<'a>(&mut self, now_ms: u64, event: Input<'a, SC, SE, TW>) -> Option> { match event { + Input::Ext(ext) => match ext { + ExtIn::ConnectTo(_remote) => { + panic!("ConnectTo is not supported") + } + ExtIn::DisconnectFrom(_node) => { + panic!("DisconnectFrom is not supported") + } + ExtIn::FeaturesControl(control) => { + let feature: Features = control.to_feature(); + let actor = FeatureControlActor::Worker(self.worker_id); + let out = self.features.on_input(&mut self.feature_ctx, feature, now_ms, FeatureWorkerInput::Control(actor, control))?; + Some(self.convert_features(now_ms, feature, out)) + } + ExtIn::ServicesControl(service, control) => { + let actor = ServiceControlActor::Worker(self.worker_id); + let out = self.services.on_input(&mut self.service_ctx, now_ms, service, ServiceWorkerInput::Control(actor, control))?; + Some(self.convert_services(now_ms, service, out)) + } + }, + Input::Worker(CrossWorker::Feature(event)) => Some(Output::Ext(ExtOut::FeaturesEvent(event))), + Input::Worker(CrossWorker::Service(service, event)) => Some(Output::Ext(ExtOut::ServicesEvent(service, event))), Input::Net(NetInput::UdpPacket(remote, buf)) => { if buf.is_empty() { return None; @@ -124,6 +166,7 @@ impl DataPlane { self.incoming_route(now_ms, remote, buf) } } + #[cfg(feature = "vpn")] Input::Net(NetInput::TunPacket(pkt)) => { let out = self.features.on_input(&mut self.feature_ctx, Features::Vpn, now_ms, FeatureWorkerInput::TunPkt(pkt))?; Some(self.convert_features(now_ms, Features::Vpn, out)) @@ -137,9 +180,17 @@ impl DataPlane { let out = self.services.on_input(&mut self.service_ctx, now_ms, service, ServiceWorkerInput::FromController(to))?; Some(self.convert_services(now_ms, service, out)) } + Input::Event(LogicEvent::ExtFeaturesEvent(worker, event)) => { + assert_eq!(self.worker_id, worker); + Some(Output::Ext(ExtOut::FeaturesEvent(event))) + } + Input::Event(LogicEvent::ExtServicesEvent(worker, service, event)) => { + assert_eq!(self.worker_id, worker); + Some(Output::Ext(ExtOut::ServicesEvent(service, event))) + } Input::Event(LogicEvent::NetNeighbour(remote, control)) => { let buf: Result, _> = (&control).try_into(); - Some(NetOutput::UdpPacket(remote, GenericBuffer::from(buf.ok()?)).into()) + Some(NetOutput::UdpPacket(remote, Buffer::from(buf.ok()?)).into()) } Input::Event(LogicEvent::NetDirect(feature, remote, _conn, meta, buf)) => { let header = meta.to_header(feature as u8, RouteRule::Direct, self.feature_ctx.node_id); @@ -164,7 +215,7 @@ impl DataPlane { } } - pub fn pop_output<'a>(&mut self, now_ms: u64) -> Option> { + pub fn pop_output<'a>(&mut self, now_ms: u64) -> Option> { if let Some(out) = self.queue_output.pop_front() { return match out { QueueOutput::Feature(feature, out) => Some(self.convert_features(now_ms, feature, out)), @@ -194,14 +245,14 @@ impl DataPlane { None } - fn incoming_route<'a>(&mut self, now_ms: u64, remote: SocketAddr, mut buf: GenericBufferMut<'a>) -> Option> { + fn incoming_route<'a>(&mut self, now_ms: u64, remote: SocketAddr, mut buf: BufferMut<'a>) -> Option> { let conn = self.conns.get_mut(&remote)?; if TransportMsgHeader::is_secure(buf[0]) { conn.decrypt_if_need(now_ms, &mut buf)?; } let header = TransportMsgHeader::try_from(&buf as &[u8]).ok()?; let action = self.feature_ctx.router.derive_action(&header.route, header.from_node, Some(conn.node())); - log::debug!("[DataPlame] Incoming rule: {:?} from: {remote}, node {:?} => action {:?}", header.route, header.from_node, action); + log::debug!("[DataPlane] Incoming rule: {:?} from: {remote}, node {:?} => action {:?}", header.route, header.from_node, action); match action { RouteAction::Reject => None, RouteAction::Local => { @@ -230,12 +281,16 @@ impl DataPlane { } } } - self.build_send_to_multi_from_mut(now_ms, remotes, buf).map(|e: NetOutput<'_>| e.into()) + if remotes.is_empty() { + self.pop_output(now_ms).map(|e| e.into()) + } else { + self.build_send_to_multi_from_mut(now_ms, remotes, buf).map(|e: NetOutput<'_>| e.into()) + } } } } - fn outgoing_route<'a>(&mut self, now_ms: u64, feature: Features, rule: RouteRule, mut meta: NetOutgoingMeta, buf: Vec) -> Option> { + fn outgoing_route<'a>(&mut self, now_ms: u64, feature: Features, rule: RouteRule, mut meta: NetOutgoingMeta, buf: Vec) -> Option> { match self.feature_ctx.router.derive_action(&rule, Some(self.feature_ctx.node_id), None) { RouteAction::Reject => { log::debug!("[DataPlane] outgoing route rule {:?} is rejected", rule); @@ -272,7 +327,7 @@ impl DataPlane { } } - fn convert_features<'a>(&mut self, now_ms: u64, feature: Features, out: FeatureWorkerOutput<'a, FeaturesControl, FeaturesEvent, FeaturesToController>) -> Output<'a, SE, TC> { + fn convert_features<'a>(&mut self, now_ms: u64, feature: Features, out: FeatureWorkerOutput<'a, FeaturesControl, FeaturesEvent, FeaturesToController>) -> Output<'a, SC, SE, TC> { self.switcher.queue_flag_task(TaskType::Feature as usize); match out { @@ -281,13 +336,20 @@ impl DataPlane { FeatureWorkerOutput::ForwardLocalToController(header, buf) => LogicControl::NetLocal(feature, header, buf).into(), FeatureWorkerOutput::ToController(control) => LogicControl::Feature(control).into(), FeatureWorkerOutput::Event(actor, event) => match actor { - FeatureControlActor::Controller => Output::Ext(ExtOut::FeaturesEvent(event)), + FeatureControlActor::Controller => Output::Control(LogicControl::ExtFeaturesEvent(event)), + FeatureControlActor::Worker(worker) => { + if self.worker_id == worker { + Output::Ext(ExtOut::FeaturesEvent(event)) + } else { + Output::Worker(worker, CrossWorker::Feature(event)) + } + } FeatureControlActor::Service(service) => { if let Some(out) = self.services.on_input(&mut self.service_ctx, now_ms, service, ServiceWorkerInput::FeatureEvent(event)) { - self.switcher.queue_flag_task(TaskType::Service as usize); - self.queue_output.push_back(QueueOutput::Service(service, out)); + self.convert_services(now_ms, service, out) + } else { + Output::Continue } - Output::Continue } }, FeatureWorkerOutput::SendDirect(conn, meta, buf) => { @@ -328,14 +390,16 @@ impl DataPlane { } } FeatureWorkerOutput::RawBroadcast2(addrs, buf) => self.build_send_to_multi(now_ms, addrs, buf).map(|e| e.into()).unwrap_or(Output::Continue), + #[cfg(feature = "vpn")] FeatureWorkerOutput::TunPkt(pkt) => NetOutput::TunPacket(pkt).into(), } } - fn convert_services<'a>(&mut self, now_ms: u64, service: ServiceId, out: ServiceWorkerOutput) -> Output<'a, SE, TC> { + fn convert_services<'a>(&mut self, now_ms: u64, service: ServiceId, out: ServiceWorkerOutput) -> Output<'a, SC, SE, TC> { self.switcher.queue_flag_task(TaskType::Service as usize); match out { + ServiceWorkerOutput::ForwardControlToController(actor, control) => LogicControl::ServicesControl(actor, service, control).into(), ServiceWorkerOutput::ForwardFeatureEventToController(event) => LogicControl::ServiceEvent(service, event).into(), ServiceWorkerOutput::ToController(tc) => LogicControl::Service(service, tc).into(), ServiceWorkerOutput::FeatureControl(control) => { @@ -344,29 +408,36 @@ impl DataPlane { .features .on_input(&mut self.feature_ctx, feature, now_ms, FeatureWorkerInput::Control(FeatureControlActor::Service(service), control)) { - self.switcher.queue_flag_task(TaskType::Feature as usize); - self.queue_output.push_back(QueueOutput::Feature(feature, out)); + self.convert_features(now_ms, feature, out) + } else { + Output::Continue } - Output::Continue } ServiceWorkerOutput::Event(actor, event) => match actor { - ServiceControlActor::Controller => Output::Ext(ExtOut::ServicesEvent(event)), + ServiceControlActor::Controller => Output::Control(LogicControl::ExtServicesEvent(service, event)), + ServiceControlActor::Worker(worker) => { + if self.worker_id == worker { + Output::Ext(ExtOut::ServicesEvent(service, event)) + } else { + Output::Worker(worker, CrossWorker::Service(service, event)) + } + } }, } } - fn build_send_to_from_mut<'a>(now: u64, conn: &mut DataPlaneConnection, remote: SocketAddr, mut buf: GenericBufferMut<'a>) -> Option> { + fn build_send_to_from_mut<'a>(now: u64, conn: &mut DataPlaneConnection, remote: SocketAddr, mut buf: BufferMut<'a>) -> Option> { conn.encrypt_if_need(now, &mut buf)?; let after = buf.freeze(); Some(NetOutput::UdpPacket(remote, after)) } - fn build_send_to_multi_from_mut<'a>(&mut self, now: u64, mut remotes: Vec, mut buf: GenericBufferMut<'a>) -> Option> { + fn build_send_to_multi_from_mut<'a>(&mut self, now: u64, mut remotes: Vec, mut buf: BufferMut<'a>) -> Option> { if TransportMsgHeader::is_secure(buf[0]) { let first = remotes.pop()?; for remote in remotes { if let Some(conn) = self.conns.get_mut(&remote) { - let mut buf = GenericBufferMut::build(&buf, 0, 12 + 16); + let mut buf = BufferMut::build(&buf, 0, 12 + 16); if let Some(_) = conn.encrypt_if_need(now, &mut buf) { let out = NetOutput::UdpPacket(remote, buf.freeze()); self.queue_output.push_back(QueueOutput::Net(out)); @@ -381,18 +452,18 @@ impl DataPlane { } } - fn build_send_to_multi<'a>(&mut self, now: u64, remotes: Vec, buf: GenericBuffer<'a>) -> Option> { + fn build_send_to_multi<'a>(&mut self, now: u64, remotes: Vec, buf: Buffer<'a>) -> Option> { if TransportMsgHeader::is_secure(buf[0]) { - let buf = GenericBufferMut::build(&buf, 0, 12 + 16); + let buf = BufferMut::build(&buf, 0, 12 + 16); self.build_send_to_multi_from_mut(now, remotes, buf) } else { Some(NetOutput::UdpPackets(remotes, buf)) } } - fn build_send_to<'a>(now: u64, conn: &mut DataPlaneConnection, remote: SocketAddr, buf: GenericBuffer<'a>) -> Option> { + fn build_send_to<'a>(now: u64, conn: &mut DataPlaneConnection, remote: SocketAddr, buf: Buffer<'a>) -> Option> { if TransportMsgHeader::is_secure(buf[0]) { - let buf = GenericBufferMut::build(&buf, 0, 12 + 16); + let buf = BufferMut::build(&buf, 0, 12 + 16); Self::build_send_to_from_mut(now, conn, remote, buf) } else { Some(NetOutput::UdpPacket(remote, buf)) diff --git a/packages/network/src/data_plane/connection.rs b/packages/network/src/data_plane/connection.rs index 599bd435..cd6daf8b 100644 --- a/packages/network/src/data_plane/connection.rs +++ b/packages/network/src/data_plane/connection.rs @@ -2,7 +2,7 @@ use std::net::SocketAddr; use atm0s_sdn_identity::{ConnId, NodeId}; -use crate::base::{GenericBufferMut, SecureContext, TransportMsgHeader}; +use crate::base::{BufferMut, SecureContext, TransportMsgHeader}; pub struct DataPlaneConnection { node: NodeId, @@ -26,7 +26,7 @@ impl DataPlaneConnection { } /// This will encrypt without first byte, which is used for TransportMsgHeader meta - pub fn encrypt_if_need<'a>(&mut self, now: u64, buf: &mut GenericBufferMut<'a>) -> Option<()> { + pub fn encrypt_if_need<'a>(&mut self, now: u64, buf: &mut BufferMut<'a>) -> Option<()> { if buf.len() < 1 { return None; } @@ -41,7 +41,7 @@ impl DataPlaneConnection { } /// This will encrypt without first byte, which is used for TransportMsgHeader meta - pub fn decrypt_if_need<'a>(&mut self, now: u64, buf: &mut GenericBufferMut<'a>) -> Option<()> { + pub fn decrypt_if_need<'a>(&mut self, now: u64, buf: &mut BufferMut<'a>) -> Option<()> { if buf.len() < 1 { return None; } diff --git a/packages/network/src/data_plane/features.rs b/packages/network/src/data_plane/features.rs index 39b3cf36..8e11e80c 100644 --- a/packages/network/src/data_plane/features.rs +++ b/packages/network/src/data_plane/features.rs @@ -3,7 +3,7 @@ use std::net::SocketAddr; use atm0s_sdn_identity::ConnId; use sans_io_runtime::TaskSwitcher; -use crate::base::{FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, GenericBuffer, TransportMsgHeader}; +use crate::base::{Buffer, FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, TransportMsgHeader}; use crate::features::*; pub type FeaturesWorkerInput<'a> = FeatureWorkerInput<'a, FeaturesControl, FeaturesToWorker>; @@ -62,7 +62,7 @@ impl FeatureWorkerManager { conn: ConnId, remote: SocketAddr, header: TransportMsgHeader, - buf: GenericBuffer<'a>, + buf: Buffer<'a>, ) -> Option> { let out = match feature { Features::Neighbours => self.neighbours.on_network_raw(ctx, now_ms, conn, remote, header, buf).map(|a| a.into2()), @@ -82,15 +82,15 @@ impl FeatureWorkerManager { pub fn on_input<'a>(&mut self, ctx: &mut FeatureWorkerContext, feature: Features, now_ms: u64, input: FeaturesWorkerInput<'a>) -> Option> { let out = match input { - FeatureWorkerInput::Control(service, control) => match control { - FeaturesControl::Neighbours(control) => self.neighbours.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), - FeaturesControl::Data(control) => self.data.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), - FeaturesControl::RouterSync(control) => self.router_sync.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), - FeaturesControl::Vpn(control) => self.vpn.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), - FeaturesControl::DhtKv(control) => self.dht_kv.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), - FeaturesControl::PubSub(control) => self.pubsub.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), - FeaturesControl::Alias(control) => self.alias.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), - FeaturesControl::Socket(control) => self.socket.on_input(ctx, now_ms, FeatureWorkerInput::Control(service, control)).map(|a| a.into2()), + FeatureWorkerInput::Control(actor, control) => match control { + FeaturesControl::Neighbours(control) => self.neighbours.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), + FeaturesControl::Data(control) => self.data.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), + FeaturesControl::RouterSync(control) => self.router_sync.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), + FeaturesControl::Vpn(control) => self.vpn.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), + FeaturesControl::DhtKv(control) => self.dht_kv.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), + FeaturesControl::PubSub(control) => self.pubsub.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), + FeaturesControl::Alias(control) => self.alias.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), + FeaturesControl::Socket(control) => self.socket.on_input(ctx, now_ms, FeatureWorkerInput::Control(actor, control)).map(|a| a.into2()), }, FeatureWorkerInput::FromController(is_broadcast, to) => match to { FeaturesToWorker::Neighbours(to) => self.neighbours.on_input(ctx, now_ms, FeatureWorkerInput::FromController(is_broadcast, to)).map(|a| a.into2()), @@ -105,6 +105,7 @@ impl FeatureWorkerManager { FeatureWorkerInput::Network(_conn, _header, _buf) => { panic!("should call above on_network_raw") } + #[cfg(feature = "vpn")] FeatureWorkerInput::TunPkt(pkt) => self.vpn.on_input(ctx, now_ms, FeatureWorkerInput::TunPkt(pkt)).map(|a| a.into2()), FeatureWorkerInput::Local(header, buf) => match feature { Features::Neighbours => self.neighbours.on_input(ctx, now_ms, FeatureWorkerInput::Local(header, buf)).map(|a| a.into2()), diff --git a/packages/network/src/data_plane/services.rs b/packages/network/src/data_plane/services.rs index 99bc32b0..8cc07ed6 100644 --- a/packages/network/src/data_plane/services.rs +++ b/packages/network/src/data_plane/services.rs @@ -8,7 +8,7 @@ use crate::features::{FeaturesControl, FeaturesEvent}; /// To manage the services we need to create an object that will hold the services pub struct ServiceWorkerManager { - services: [Option>>; 256], + services: [Option>>; 256], switcher: TaskSwitcher, _tmp: PhantomData, } @@ -37,8 +37,8 @@ impl ServiceWorkerManager< ctx: &ServiceWorkerCtx, now: u64, id: ServiceId, - input: ServiceWorkerInput, - ) -> Option> { + input: ServiceWorkerInput, + ) -> Option> { let service = self.services[*id as usize].as_mut()?; let out = service.on_input(ctx, now, input); if out.is_some() { @@ -47,7 +47,7 @@ impl ServiceWorkerManager< out } - pub fn pop_output(&mut self, ctx: &ServiceWorkerCtx) -> Option<(ServiceId, ServiceWorkerOutput)> { + pub fn pop_output(&mut self, ctx: &ServiceWorkerCtx) -> Option<(ServiceId, ServiceWorkerOutput)> { loop { let s = &mut self.switcher; let index = s.queue_current()?; diff --git a/packages/network/src/features/alias.rs b/packages/network/src/features/alias.rs index bbb7be44..ef88649f 100644 --- a/packages/network/src/features/alias.rs +++ b/packages/network/src/features/alias.rs @@ -94,7 +94,7 @@ impl AliasFeature { slot.waiters.push(actor); } else if let Some(slot) = self.hint_slots.get(&alias) { if slot.ts + HINT_TIMEOUT_MS >= now_ms { - log::debug!("[AliasFeature] Alias {alias} is very newly added to hint {} => reuse", slot.node); + log::debug!("[AliasFeature] Alias {alias} is very newly added ({} vs now {}) to hint {} => reuse", slot.ts, now_ms, slot.node); self.queue.push_back(FeatureOutput::Event(actor, Event::QueryResult(alias, Some(FoundLocation::CachedHint(slot.node))))); } else { log::debug!("[AliasFeature] Alias {alias} is not in query state but has hint {} => check hint", slot.node); diff --git a/packages/network/src/features/data.rs b/packages/network/src/features/data.rs index 4377ea80..9c463e2b 100644 --- a/packages/network/src/features/data.rs +++ b/packages/network/src/features/data.rs @@ -83,6 +83,7 @@ impl Feature for DataFeature { Control::SendRule(rule, ttl, data) => { let data = match actor { FeatureControlActor::Controller => DataMsg::DataController { data }, + FeatureControlActor::Worker(_) => todo!(), FeatureControlActor::Service(service) => DataMsg::DataService { service: *service, data }, }; let msg = bincode::serialize(&data).expect("should work"); diff --git a/packages/network/src/features/pubsub/controller.rs b/packages/network/src/features/pubsub/controller.rs index a5de9ed8..5a553bc0 100644 --- a/packages/network/src/features/pubsub/controller.rs +++ b/packages/network/src/features/pubsub/controller.rs @@ -5,8 +5,10 @@ use std::{ use crate::base::{ConnectionEvent, Feature, FeatureContext, FeatureControlActor, FeatureInput, FeatureOutput, FeatureSharedInput}; +use self::source_hint::SourceHintLogic; + use super::{ - msg::{ChannelId, RelayControl, RelayId}, + msg::{ChannelId, Feedback, RelayControl, RelayId, SourceHint}, ChannelControl, ChannelEvent, Control, Event, RelayWorkerControl, ToController, ToWorker, }; @@ -14,9 +16,12 @@ pub const RELAY_TIMEOUT: u64 = 10_000; pub const RELAY_STICKY_MS: u64 = 5 * 60 * 1000; //sticky route path in 5 minutes mod consumers; +mod feedbacks; mod local_relay; mod remote_relay; +mod source_hint; +use atm0s_sdn_identity::NodeId; use local_relay::LocalRelay; use remote_relay::RemoteRelay; @@ -24,11 +29,15 @@ use remote_relay::RemoteRelay; pub enum GenericRelayOutput { ToWorker(RelayWorkerControl), RouteChanged(FeatureControlActor), + Feedback(Vec, Feedback), } pub trait GenericRelay { fn on_tick(&mut self, now: u64); + fn on_pub_start(&mut self, actor: FeatureControlActor); + fn on_pub_stop(&mut self, actor: FeatureControlActor); fn on_local_sub(&mut self, now: u64, actor: FeatureControlActor); + fn on_local_feedback(&mut self, now: u64, actor: FeatureControlActor, feedback: Feedback); fn on_local_unsub(&mut self, now: u64, actor: FeatureControlActor); fn on_remote(&mut self, now: u64, remote: SocketAddr, control: RelayControl); fn conn_disconnected(&mut self, now: u64, remote: SocketAddr); @@ -39,6 +48,7 @@ pub trait GenericRelay { pub struct PubSubFeature { relays: HashMap>, + source_hints: HashMap, queue: VecDeque>, } @@ -46,6 +56,7 @@ impl PubSubFeature { pub fn new() -> Self { Self { relays: HashMap::new(), + source_hints: HashMap::new(), queue: VecDeque::new(), } } @@ -53,10 +64,10 @@ impl PubSubFeature { fn get_relay(&mut self, ctx: &FeatureContext, relay_id: RelayId, auto_create: bool) -> Option<&mut Box> { if !self.relays.contains_key(&relay_id) && auto_create { let relay: Box = if ctx.node_id == relay_id.1 { - log::info!("[PubSubFeatureController] Creating new LocalRelay: {}", relay_id.0); + log::info!("[PubSubFeatureController] Creating new LocalRelay: {:?}", relay_id); Box::new(LocalRelay::default()) } else { - log::info!("[PubSubFeatureController] Creating new RemoteController: {:?}", relay_id); + log::info!("[PubSubFeatureController] Creating new RemoteRelay: {:?}", relay_id); Box::new(RemoteRelay::new(ctx.session)) }; self.relays.insert(relay_id, relay); @@ -64,16 +75,74 @@ impl PubSubFeature { self.relays.get_mut(&relay_id) } + fn get_source_hint(&mut self, node_id: NodeId, session: u64, channel: ChannelId, auto_create: bool) -> Option<&mut SourceHintLogic> { + if !self.source_hints.contains_key(&channel) && auto_create { + log::info!("[PubSubFeatureController] Creating new SourceHintLogic: {}", channel); + self.source_hints.insert(channel, SourceHintLogic::new(node_id, session)); + } + self.source_hints.get_mut(&channel) + } + fn on_local(&mut self, ctx: &FeatureContext, now: u64, actor: FeatureControlActor, channel: ChannelId, control: ChannelControl) { match control { + ChannelControl::SubAuto => { + log::info!("[PubSubFeatureController] SubAuto for {} from {:?}", channel, actor); + let sh = self.get_source_hint(ctx.node_id, ctx.session, channel, true).expect("Should create"); + sh.on_local(now, actor, source_hint::LocalCmd::Subscribe); + self.pop_single_source_hint(ctx, now, channel); + } + ChannelControl::UnsubAuto => { + log::info!("[PubSubFeatureController] UnsubAuto for {} from {:?}", channel, actor); + if let Some(sh) = self.get_source_hint(ctx.node_id, ctx.session, channel, false) { + sh.on_local(now, actor, source_hint::LocalCmd::Unsubscribe); + self.pop_single_source_hint(ctx, now, channel); + } + } + ChannelControl::PubStart => { + log::info!("[PubSubFeatureController] PubStart for {} from {:?}", channel, actor); + let relay_id = RelayId(channel, ctx.node_id); + let relay = self.get_relay(ctx, relay_id, true).expect("Should create"); + relay.on_pub_start(actor); + Self::pop_single_relay(relay_id, self.relays.get_mut(&relay_id).expect("Should have"), &mut self.queue); + + let sh = self.get_source_hint(ctx.node_id, ctx.session, channel, true).expect("Should create"); + sh.on_local(now, actor, source_hint::LocalCmd::Register); + self.pop_single_source_hint(ctx, now, channel); + } + ChannelControl::PubStop => { + log::info!("[PubSubFeatureController] PubStop for {} from {:?}", channel, actor); + let relay_id = RelayId(channel, ctx.node_id); + if let Some(relay) = self.relays.get_mut(&relay_id) { + relay.on_pub_stop(actor); + Self::pop_single_relay(relay_id, self.relays.get_mut(&relay_id).expect("Should have"), &mut self.queue); + } + + if let Some(sh) = self.get_source_hint(ctx.node_id, ctx.session, channel, false) { + sh.on_local(now, actor, source_hint::LocalCmd::Unregister); + self.pop_single_source_hint(ctx, now, channel); + } + } ChannelControl::SubSource(source) => { + log::info!("[PubSubFeatureController] SubSource(source) for {} from {:?}", channel, actor); let relay_id = RelayId(channel, source); let relay = self.get_relay(ctx, relay_id, true).expect("Should create"); log::debug!("[PubSubFeatureController] Sub for {:?} from {:?}", relay_id, actor); relay.on_local_sub(now, actor); Self::pop_single_relay(relay_id, self.relays.get_mut(&relay_id).expect("Should have"), &mut self.queue); } + ChannelControl::FeedbackAuto(fb) => { + if let Some(sh) = self.get_source_hint(ctx.node_id, ctx.session, channel, false) { + for source in sh.sources() { + let relay_id = RelayId(channel, source); + let relay = self.get_relay(ctx, relay_id, true).expect("Should create"); + log::debug!("[PubSubFeatureController] Feedback for {:?} from {:?}", relay_id, actor); + relay.on_local_feedback(now, actor, fb); + Self::pop_single_relay(relay_id, self.relays.get_mut(&relay_id).expect("Should have"), &mut self.queue); + } + } + } ChannelControl::UnsubSource(source) => { + log::info!("[PubSubFeatureController] UnsubSource(source) for {} from {:?}", channel, actor); let relay_id = RelayId(channel, source); if let Some(relay) = self.relays.get_mut(&relay_id) { log::debug!("[PubSubFeatureController] Unsub for {:?} from {:?}", relay_id, actor); @@ -113,9 +182,9 @@ impl PubSubFeature { } } - fn on_remote(&mut self, ctx: &FeatureContext, now: u64, remote: SocketAddr, relay_id: RelayId, control: RelayControl) { + fn on_remote_relay_control(&mut self, ctx: &FeatureContext, now: u64, remote: SocketAddr, relay_id: RelayId, control: RelayControl) { if let Some(_) = self.get_relay(ctx, relay_id, control.should_create()) { - let relay = self.relays.get_mut(&relay_id).expect("Should have relay"); + let relay: &mut Box = self.relays.get_mut(&relay_id).expect("Should have relay"); log::debug!("[PubSubFeatureController] Remote control for {:?} from {:?}: {:?}", relay_id, remote, control); relay.on_remote(now, remote, control); Self::pop_single_relay(relay_id, relay, &mut self.queue); @@ -127,18 +196,65 @@ impl PubSubFeature { } } + fn on_remote_source_hint_control(&mut self, ctx: &FeatureContext, now: u64, remote: SocketAddr, channel: ChannelId, control: SourceHint) { + if let Some(sh) = self.get_source_hint(ctx.node_id, ctx.session, channel, control.should_create()) { + log::debug!("[PubSubFeatureController] SourceHint control for {:?} from {:?}: {:?}", channel, remote, control); + sh.on_remote(now, remote, control); + self.pop_single_source_hint(ctx, now, channel); + + let sh = self.source_hints.get_mut(&channel).expect("Should have source hint"); + if sh.should_clear() { + self.source_hints.remove(&channel); + } + } else { + log::warn!("[PubSubFeatureController] Remote control for unknown channel {:?}", channel); + } + } + fn pop_single_relay(relay_id: RelayId, relay: &mut Box, queue: &mut VecDeque>) { while let Some(control) = relay.pop_output() { - queue.push_back(match control { - GenericRelayOutput::ToWorker(control) => FeatureOutput::ToWorker(true, ToWorker::RelayWorkerControl(relay_id, control)), - GenericRelayOutput::RouteChanged(actor) => FeatureOutput::Event(actor, Event(relay_id.0, ChannelEvent::RouteChanged(relay_id.1))), - }); + match control { + GenericRelayOutput::ToWorker(control) => queue.push_back(FeatureOutput::ToWorker(true, ToWorker::RelayControl(relay_id, control))), + GenericRelayOutput::RouteChanged(actor) => queue.push_back(FeatureOutput::Event(actor, Event(relay_id.0, ChannelEvent::RouteChanged(relay_id.1)))), + GenericRelayOutput::Feedback(actors, fb) => { + log::debug!("[PubsubController] Feedback for {:?} {:?} to actors {:?}", relay_id, fb, actors); + for actor in actors { + queue.push_back(FeatureOutput::Event(actor, Event(relay_id.0, ChannelEvent::FeedbackData(fb.clone())))); + } + } + }; + } + } + + fn pop_single_source_hint(&mut self, ctx: &FeatureContext, now: u64, channel: ChannelId) { + loop { + let sh = self.source_hints.get_mut(&channel).expect("Should have source hint"); + let out = if let Some(out) = sh.pop_output() { + out + } else { + return; + }; + match out { + source_hint::Output::SendRemote(dest, control) => { + self.queue.push_back(FeatureOutput::ToWorker(true, ToWorker::SourceHint(channel, dest, control))); + } + source_hint::Output::SubscribeSource(actors, source) => { + for actor in actors { + self.on_local(ctx, now, actor, channel, ChannelControl::SubSource(source)); + } + } + source_hint::Output::UnsubscribeSource(actors, source) => { + for actor in actors { + self.on_local(ctx, now, actor, channel, ChannelControl::UnsubSource(source)); + } + } + } } } } impl Feature for PubSubFeature { - fn on_shared_input(&mut self, _ctx: &FeatureContext, now: u64, input: FeatureSharedInput) { + fn on_shared_input(&mut self, ctx: &FeatureContext, now: u64, input: FeatureSharedInput) { match input { FeatureSharedInput::Tick(_) => { let mut clears = vec![]; @@ -153,6 +269,23 @@ impl Feature for PubSubFeature { for relay_id in clears { self.relays.remove(&relay_id); } + + let mut clears = vec![]; + let mut not_clears = vec![]; + for (channel, sh) in self.source_hints.iter_mut() { + if sh.should_clear() { + clears.push(*channel); + } else { + sh.on_tick(now); + not_clears.push(*channel); + } + } + for channel in clears { + self.source_hints.remove(&channel); + } + for channel in not_clears { + self.pop_single_source_hint(ctx, now, channel); + } } FeatureSharedInput::Connection(event) => match event { ConnectionEvent::Disconnected(ctx) => { @@ -168,8 +301,11 @@ impl Feature for PubSubFeature { fn on_input<'a>(&mut self, ctx: &FeatureContext, now_ms: u64, input: FeatureInput<'a, Control, ToController>) { match input { - FeatureInput::FromWorker(ToController::RemoteControl(remote, relay_id, control)) => { - self.on_remote(ctx, now_ms, remote, relay_id, control); + FeatureInput::FromWorker(ToController::RelayControl(remote, relay_id, control)) => { + self.on_remote_relay_control(ctx, now_ms, remote, relay_id, control); + } + FeatureInput::FromWorker(ToController::SourceHint(remote, channel, control)) => { + self.on_remote_source_hint_control(ctx, now_ms, remote, channel, control); } FeatureInput::Control(actor, Control(channel, control)) => { self.on_local(ctx, now_ms, actor, channel, control); diff --git a/packages/network/src/features/pubsub/controller/feedbacks.rs b/packages/network/src/features/pubsub/controller/feedbacks.rs new file mode 100644 index 00000000..a6314d15 --- /dev/null +++ b/packages/network/src/features/pubsub/controller/feedbacks.rs @@ -0,0 +1,211 @@ +use std::{collections::VecDeque, net::SocketAddr}; + +use crate::{base::FeatureControlActor, features::pubsub::msg::Feedback}; + +#[derive(Debug, PartialEq, Eq)] +enum FeedbackSource { + Local(FeatureControlActor), + Remote(SocketAddr), +} + +#[derive(Debug, Default)] +struct SingleFeedbackKind { + kind: u8, + feedbacks: Vec<(FeedbackSource, Feedback, u64)>, + feedbacks_updated: bool, + last_feedback_ts: Option, +} + +impl SingleFeedbackKind { + fn on_local_feedback(&mut self, now: u64, actor: FeatureControlActor, fb: Feedback) { + self.feedbacks_updated = true; + let source = FeedbackSource::Local(actor); + if let Some(index) = self.feedbacks.iter().position(|(a, _, _)| *a == source) { + self.feedbacks[index] = (source, fb, now); + } else { + self.feedbacks.push((source, fb, now)); + } + } + + fn on_remote_feedback(&mut self, now: u64, remote: SocketAddr, fb: Feedback) { + self.feedbacks_updated = true; + let source = FeedbackSource::Remote(remote); + if let Some(index) = self.feedbacks.iter().position(|(a, _, _)| *a == source) { + self.feedbacks[index] = (source, fb, now); + } else { + self.feedbacks.push((source, fb, now)); + } + } + + fn process_feedbacks(&mut self, now: u64) -> Option { + if !self.feedbacks_updated { + self.feedbacks.retain(|(_, fb, last_ts)| now < last_ts + fb.timeout_ms as u64); + return None; + } + self.feedbacks_updated = false; + log::debug!("[FeedbacksAggerator] on process feedback for kind {}", self.kind); + let mut aggerated_fb: Option = None; + for (_, fb, _) in &self.feedbacks { + if let Some(aggerated_fb) = &mut aggerated_fb { + *aggerated_fb = aggerated_fb.clone() + fb.clone(); + } else { + aggerated_fb = Some(fb.clone()); + } + } + self.feedbacks.retain(|(_, fb, last_ts)| now < last_ts + fb.timeout_ms as u64); + + if let Some(last_fb_ts) = self.last_feedback_ts { + if let Some(fb) = &aggerated_fb { + if now < last_fb_ts + fb.interval_ms as u64 { + return None; + } + } + } + aggerated_fb + } +} + +#[derive(Debug, Default)] +pub struct FeedbacksAggerator { + feedbacks: Vec, + queue: VecDeque, +} + +impl FeedbacksAggerator { + pub fn on_tick(&mut self, now: u64) { + log::debug!("[FeedbacksAggerator] on tick {now}"); + self.process_feedbacks(now); + } + + pub fn on_local_feedback(&mut self, now: u64, actor: FeatureControlActor, fb: Feedback) { + log::debug!("[FeedbacksAggerator] on local_feedback from {:?} {:?}", actor, fb); + let kind = self.get_fb_kind(fb.kind); + kind.on_local_feedback(now, actor, fb); + if let Some(fb) = kind.process_feedbacks(now) { + self.queue.push_back(fb); + } + } + + pub fn on_remote_feedback(&mut self, now: u64, remote: SocketAddr, fb: Feedback) { + log::debug!("[FeedbacksAggerator] on remote_feedback from {} {:?}", remote, fb); + let kind = self.get_fb_kind(fb.kind); + kind.on_remote_feedback(now, remote, fb); + if let Some(fb) = kind.process_feedbacks(now) { + self.queue.push_back(fb); + } + } + + pub fn pop_output(&mut self) -> Option { + self.queue.pop_front() + } + + fn process_feedbacks(&mut self, now: u64) { + for kind in &mut self.feedbacks { + while let Some(fb) = kind.process_feedbacks(now) { + self.queue.push_back(fb); + } + } + self.feedbacks.retain(|x| !x.feedbacks.is_empty()); + } + + fn get_fb_kind(&mut self, kind: u8) -> &mut SingleFeedbackKind { + if let Some(index) = self.feedbacks.iter().position(|x| x.kind == kind) { + &mut self.feedbacks[index] + } else { + let new = SingleFeedbackKind { + kind, + feedbacks: Vec::new(), + feedbacks_updated: false, + last_feedback_ts: None, + }; + self.feedbacks.push(new); + self.feedbacks.last_mut().expect("Should got last element") + } + } +} + +#[cfg(test)] +mod test { + use crate::base::FeatureControlActor; + + use super::{Feedback, FeedbacksAggerator}; + + #[test] + fn aggerator_single() { + let mut aggerator = FeedbacksAggerator::default(); + let fb = Feedback::simple(0, 10, 1000, 2000); + aggerator.on_local_feedback(1, FeatureControlActor::Controller, fb); + aggerator.on_tick(2); + assert_eq!(aggerator.pop_output(), Some(fb)); + assert_eq!(aggerator.pop_output(), None); + } + + #[test] + fn aggerator_single_rewrite() { + let mut aggerator = FeedbacksAggerator::default(); + let fb = Feedback::simple(0, 10, 1000, 2000); + aggerator.on_local_feedback(0, FeatureControlActor::Controller, fb); + assert_eq!(aggerator.pop_output(), Some(fb)); + assert_eq!(aggerator.pop_output(), None); + + let fb = Feedback::simple(0, 11, 1000, 2000); + aggerator.on_local_feedback(2, FeatureControlActor::Controller, fb); + aggerator.on_tick(2000); + assert_eq!(aggerator.pop_output(), Some(fb)); + assert_eq!(aggerator.pop_output(), None); + } + + #[test] + fn aggerator_multi_sources() { + let mut aggerator = FeedbacksAggerator::default(); + let fb = Feedback::simple(0, 10, 1000, 2000); + aggerator.on_local_feedback(1, FeatureControlActor::Controller, fb); + assert_eq!(aggerator.pop_output(), Some(fb)); + assert_eq!(aggerator.pop_output(), None); + + let fb = Feedback::simple(0, 20, 1500, 3000); + aggerator.on_local_feedback(2, FeatureControlActor::Worker(0), fb); + + aggerator.on_tick(1000); + assert_eq!( + aggerator.pop_output(), + Some(Feedback { + kind: 0, + count: 2, + max: 20, + min: 10, + sum: 30, + interval_ms: 1000, + timeout_ms: 3000, + }) + ); + assert_eq!(aggerator.pop_output(), None); + } + + #[test] + fn aggerator_multi_types() { + let mut aggerator = FeedbacksAggerator::default(); + let fb1 = Feedback::simple(0, 10, 1000, 2000); + aggerator.on_local_feedback(1, FeatureControlActor::Controller, fb1); + + let fb2 = Feedback::simple(1, 20, 1500, 3000); + aggerator.on_local_feedback(2, FeatureControlActor::Controller, fb2); + + aggerator.on_tick(3); + assert_eq!(aggerator.pop_output(), Some(fb1)); + assert_eq!(aggerator.pop_output(), Some(fb2)); + assert_eq!(aggerator.pop_output(), None); + } + + #[test] + fn aggerator_auto_clear_kind_nodata() { + let mut aggerator = FeedbacksAggerator::default(); + let fb1 = Feedback::simple(0, 10, 1000, 2000); + aggerator.on_local_feedback(0, FeatureControlActor::Controller, fb1); + + assert_eq!(aggerator.feedbacks.len(), 1); + + aggerator.on_tick(2000); + assert_eq!(aggerator.feedbacks.len(), 0); + } +} diff --git a/packages/network/src/features/pubsub/controller/local_relay.rs b/packages/network/src/features/pubsub/controller/local_relay.rs index 8ef89ba8..53f71a2a 100644 --- a/packages/network/src/features/pubsub/controller/local_relay.rs +++ b/packages/network/src/features/pubsub/controller/local_relay.rs @@ -1,35 +1,65 @@ -use crate::{base::FeatureControlActor, features::pubsub::msg::RelayControl}; +use std::net::SocketAddr; -use super::{consumers::RelayConsumers, GenericRelay, GenericRelayOutput}; +use crate::{ + base::FeatureControlActor, + features::pubsub::msg::{Feedback, RelayControl}, +}; + +use super::{consumers::RelayConsumers, feedbacks::FeedbacksAggerator, GenericRelay, GenericRelayOutput}; #[derive(Default)] pub struct LocalRelay { consumers: RelayConsumers, + feedbacks: FeedbacksAggerator, + publishers: Vec, } impl GenericRelay for LocalRelay { fn on_tick(&mut self, now: u64) { + self.feedbacks.on_tick(now); self.consumers.on_tick(now); } + fn on_pub_start(&mut self, actor: FeatureControlActor) { + if !self.publishers.contains(&actor) { + log::debug!("[LocalRelay] on_pub_start {:?}", actor); + self.publishers.push(actor); + } + } + + fn on_pub_stop(&mut self, actor: FeatureControlActor) { + if let Some(index) = self.publishers.iter().position(|x| *x == actor) { + log::debug!("[LocalRelay] on_pub_stop {:?}", actor); + self.publishers.swap_remove(index); + } + } + fn on_local_sub(&mut self, now: u64, actor: FeatureControlActor) { self.consumers.on_local_sub(now, actor); } + fn on_local_feedback(&mut self, now: u64, actor: FeatureControlActor, feedback: Feedback) { + self.feedbacks.on_local_feedback(now, actor, feedback); + } + fn on_local_unsub(&mut self, now: u64, actor: FeatureControlActor) { self.consumers.on_local_unsub(now, actor); } - fn on_remote(&mut self, now: u64, remote: std::net::SocketAddr, control: RelayControl) { - self.consumers.on_remote(now, remote, control); + fn on_remote(&mut self, now: u64, remote: SocketAddr, control: RelayControl) { + if let RelayControl::Feedback(fb) = control { + self.feedbacks.on_remote_feedback(now, remote, fb); + } else { + self.consumers.on_remote(now, remote, control); + } } - fn conn_disconnected(&mut self, now: u64, remote: std::net::SocketAddr) { + fn conn_disconnected(&mut self, now: u64, remote: SocketAddr) { self.consumers.conn_disconnected(now, remote); } fn should_clear(&self) -> bool { - self.consumers.should_clear() + self.consumers.should_clear() && self.publishers.is_empty() } fn relay_dests(&self) -> Option<(&[FeatureControlActor], bool)> { @@ -37,6 +67,10 @@ impl GenericRelay for LocalRelay { } fn pop_output(&mut self) -> Option { + if let Some(fb) = self.feedbacks.pop_output() { + log::debug!("[LocalRelay] pop_output feedback {:?}", fb); + return Some(GenericRelayOutput::Feedback(self.publishers.clone(), fb)); + } self.consumers.pop_output().map(GenericRelayOutput::ToWorker) } } diff --git a/packages/network/src/features/pubsub/controller/remote_relay.rs b/packages/network/src/features/pubsub/controller/remote_relay.rs index 9826a11a..5c084ac1 100644 --- a/packages/network/src/features/pubsub/controller/remote_relay.rs +++ b/packages/network/src/features/pubsub/controller/remote_relay.rs @@ -6,17 +6,31 @@ use crate::{ base::FeatureControlActor, - features::pubsub::{msg::RelayControl, RelayWorkerControl}, + features::pubsub::{ + msg::{Feedback, RelayControl}, + RelayWorkerControl, + }, }; use std::{collections::VecDeque, net::SocketAddr}; -use super::{consumers::RelayConsumers, GenericRelay, GenericRelayOutput, RELAY_STICKY_MS, RELAY_TIMEOUT}; +use super::{consumers::RelayConsumers, feedbacks::FeedbacksAggerator, GenericRelay, GenericRelayOutput, RELAY_STICKY_MS, RELAY_TIMEOUT}; enum RelayState { New, - Binding { consumers: RelayConsumers }, - Bound { consumers: RelayConsumers, next: SocketAddr, sticky_session_at: u64 }, - Unbinding { next: SocketAddr, started_at: u64 }, + Binding { + consumers: RelayConsumers, + feedbacks: FeedbacksAggerator, + }, + Bound { + consumers: RelayConsumers, + feedbacks: FeedbacksAggerator, + next: SocketAddr, + sticky_session_at: u64, + }, + Unbinding { + next: SocketAddr, + started_at: u64, + }, Unbound, } @@ -40,13 +54,23 @@ impl RemoteRelay { queue.push_back(GenericRelayOutput::ToWorker(control)); } } + + fn pop_feedbacks_out(remote: SocketAddr, feedbacks: &mut FeedbacksAggerator, queue: &mut VecDeque) { + while let Some(fb) = feedbacks.pop_output() { + queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendFeedback(fb, remote))); + } + } } impl GenericRelay for RemoteRelay { fn on_tick(&mut self, now: u64) { match &mut self.state { RelayState::Bound { - next, consumers, sticky_session_at, .. + next, + consumers, + feedbacks, + sticky_session_at, + .. } => { if now >= *sticky_session_at + RELAY_STICKY_MS { log::info!("[PubSubRemoteRelay] Sticky session end for relay from {next} => trying finding better way"); @@ -55,7 +79,9 @@ impl GenericRelay for RemoteRelay { self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendSub(self.uuid, Some(*next)))); } consumers.on_tick(now); + feedbacks.on_tick(now); Self::pop_consumers_out(consumers, &mut self.queue); + Self::pop_feedbacks_out(*next, feedbacks, &mut self.queue); if consumers.should_clear() { self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendUnsub(self.uuid, *next))); @@ -85,13 +111,14 @@ impl GenericRelay for RemoteRelay { fn conn_disconnected(&mut self, now: u64, remote: SocketAddr) { match &mut self.state { - RelayState::Bound { consumers, next, .. } => { + RelayState::Bound { consumers, feedbacks, next, .. } => { consumers.conn_disconnected(now, remote); Self::pop_consumers_out(consumers, &mut self.queue); // If remote is next, this will not be consumers, because it will cause loop deps if *next == remote { let consumers = std::mem::replace(consumers, Default::default()); - self.state = RelayState::Binding { consumers }; + let feedbacks = std::mem::replace(feedbacks, Default::default()); + self.state = RelayState::Binding { consumers, feedbacks }; self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendSub(self.uuid, None))); } else if consumers.should_clear() { self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendUnsub(self.uuid, *next))); @@ -111,6 +138,14 @@ impl GenericRelay for RemoteRelay { } } + fn on_pub_start(&mut self, _actor: FeatureControlActor) { + panic!("Should not be called"); + } + + fn on_pub_stop(&mut self, _actor: FeatureControlActor) { + panic!("Should not be called"); + } + /// Add a local subscriber to the relay /// Returns true if this is the first subscriber, false otherwise fn on_local_sub(&mut self, now: u64, actor: FeatureControlActor) { @@ -118,18 +153,20 @@ impl GenericRelay for RemoteRelay { RelayState::New | RelayState::Unbound => { log::info!("[PubSubRemoteRelay] Sub in New or Unbound state => switch to Binding and send Sub message"); let mut consumers = RelayConsumers::default(); + let feedbacks = FeedbacksAggerator::default(); consumers.on_local_sub(now, actor); Self::pop_consumers_out(&mut consumers, &mut self.queue); - self.state = RelayState::Binding { consumers }; + self.state = RelayState::Binding { consumers, feedbacks }; self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendSub(self.uuid, None))); } RelayState::Unbinding { next, .. } => { log::debug!("[PubSubRemoteRelay] Sub in Unbinding state => switch to Binding with previous next {next}"); let mut consumers = RelayConsumers::default(); + let feedbacks = FeedbacksAggerator::default(); consumers.on_local_sub(now, actor); Self::pop_consumers_out(&mut consumers, &mut self.queue); self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendSub(self.uuid, Some(*next)))); - self.state = RelayState::Binding { consumers }; + self.state = RelayState::Binding { consumers, feedbacks }; } RelayState::Binding { consumers, .. } | RelayState::Bound { consumers, .. } => { log::debug!("[PubSubRemoteRelay] Sub in Binding or Bound state => just add to list"); @@ -139,6 +176,20 @@ impl GenericRelay for RemoteRelay { } } + /// Sending feedback to sources, for avoiding wasting bandwidth, the feedback will be aggregated and send in each window_ms + fn on_local_feedback(&mut self, now: u64, actor: FeatureControlActor, feedback: Feedback) { + match &mut self.state { + RelayState::Binding { feedbacks, .. } => { + feedbacks.on_local_feedback(now, actor, feedback); + } + RelayState::Bound { next, feedbacks, .. } => { + feedbacks.on_local_feedback(now, actor, feedback); + Self::pop_feedbacks_out(*next, feedbacks, &mut self.queue); + } + _ => {} + } + } + /// Remove a local subscriber from the relay /// Returns true if this is the last subscriber, false otherwise fn on_local_unsub(&mut self, now: u64, actor: FeatureControlActor) { @@ -176,22 +227,26 @@ impl GenericRelay for RemoteRelay { return; } match &mut self.state { - RelayState::Binding { consumers } => { + RelayState::Binding { consumers, feedbacks } => { log::info!("[Relay] SubOK for binding relay {} from {remote} => switched to Bound with this remote", self.uuid); self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::RouteSetSource(remote))); let consumers = std::mem::replace(consumers, Default::default()); + let feedbacks = std::mem::replace(feedbacks, Default::default()); self.state = RelayState::Bound { next: remote, consumers, + feedbacks, sticky_session_at: now, }; } - RelayState::Bound { next, sticky_session_at, consumers } => { + RelayState::Bound { + next, sticky_session_at, consumers, .. + } => { if *next == remote { - log::info!("[Relay] SubOK for bound relay {} from same remote {remote} => renew sticky session", self.uuid); + log::debug!("[Relay] SubOK for bound relay {} from same remote {remote} => renew sticky session", self.uuid); *sticky_session_at = now; } else { - log::info!("[Relay] SubOK for bound relay {} from other remote {remote} => renew stick session and Unsub older", self.uuid); + log::warn!("[Relay] SubOK for bound relay {} from other remote {remote} => renew stick session and Unsub older", self.uuid); let (locals, has_remote) = consumers.relay_dests(); self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendUnsub(self.uuid, *next))); if has_remote { @@ -238,20 +293,32 @@ impl GenericRelay for RemoteRelay { _ => {} } } + RelayControl::Feedback(fb) => match &mut self.state { + RelayState::Binding { feedbacks, .. } => { + feedbacks.on_remote_feedback(now, remote, fb); + } + RelayState::Bound { next, feedbacks, .. } => { + feedbacks.on_remote_feedback(now, remote, fb); + Self::pop_feedbacks_out(*next, feedbacks, &mut self.queue); + } + _ => {} + }, _ => match &mut self.state { RelayState::New | RelayState::Unbound => { let mut consumers = RelayConsumers::default(); + let feedbacks = FeedbacksAggerator::default(); consumers.on_remote(now, remote, control); Self::pop_consumers_out(&mut consumers, &mut self.queue); self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendSub(self.uuid, None))); - self.state = RelayState::Binding { consumers }; + self.state = RelayState::Binding { consumers, feedbacks }; } RelayState::Unbinding { next, .. } => { let mut consumers = RelayConsumers::default(); + let feedbacks = FeedbacksAggerator::default(); consumers.on_remote(now, remote, control); Self::pop_consumers_out(&mut consumers, &mut self.queue); self.queue.push_back(GenericRelayOutput::ToWorker(RelayWorkerControl::SendSub(self.uuid, Some(*next)))); - self.state = RelayState::Binding { consumers }; + self.state = RelayState::Binding { consumers, feedbacks }; } RelayState::Binding { consumers, .. } => { consumers.on_remote(now, remote, control); diff --git a/packages/network/src/features/pubsub/controller/source_hint.rs b/packages/network/src/features/pubsub/controller/source_hint.rs new file mode 100644 index 00000000..4b62b6cc --- /dev/null +++ b/packages/network/src/features/pubsub/controller/source_hint.rs @@ -0,0 +1,808 @@ +//! Source Hint is a way we create a notification tree for sources and subscribers. +//! Because if this method is not latency focus thefore it only used for low frequency message like Source Changed notification. +//! +//! The main idea is instead of a single node take care of all subscribers, which can cause overload and waste of resource, +//! we create a tree of nodes that relay the message to the next hop. +//! +//! Term: root node is the node which closest to channel id in XOR distance. +//! +//! Subscribe: a node need to subscribe a channel notification, it will send a message to next node. +//! Each time a node receive a subscribe message it will add the subscriber to the list and send a message to next hop. +//! +//! Register: a node need to register as a source for channel, it will send a message to next node. +//! Each time a node receive a register message it will add the source to the list and send: +//! - to all subscribers except sender if it is a new source. +//! - to next hop +//! +//! For ensure keep the tree clean and sync, each Subscriber and Source will resend message in each 1 seconds for keep alive. +//! For solve the case which network chaged cause root node changed, a node receive Register or Subscribe will reply with a RegisterOk, SubscribeOk +//! for that, the sender will know the real next hop and only accept notification from selected node. + +use std::{ + collections::{BTreeMap, VecDeque}, + net::SocketAddr, +}; + +use atm0s_sdn_identity::NodeId; + +use crate::{base::FeatureControlActor, features::pubsub::msg::SourceHint}; + +const TIMEOUT_MS: u64 = 10_000; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LocalCmd { + Register, + Unregister, + Subscribe, + Unsubscribe, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Output { + SendRemote(Option, SourceHint), + SubscribeSource(Vec, NodeId), + UnsubscribeSource(Vec, NodeId), +} + +#[derive(Default)] +pub struct SourceHintLogic { + node_id: NodeId, + session_id: u64, + remote_sources: BTreeMap, + remote_subscribers: BTreeMap, + local_sources: Vec, + local_subscribers: Vec, + next_hop: Option, + queue: VecDeque, +} + +impl SourceHintLogic { + pub fn new(node_id: NodeId, session_id: u64) -> Self { + Self { + node_id, + session_id, + ..Default::default() + } + } + + /// return sources as local and remote sources + pub fn sources(&self) -> Vec { + if self.local_sources.is_empty() { + self.remote_sources.keys().cloned().collect() + } else { + self.remote_sources.keys().chain(std::iter::once(&self.node_id)).cloned().collect() + } + } + + pub fn on_tick(&mut self, now_ms: u64) { + let mut timeout_subscribes = vec![]; + for (remote, last_tick) in &self.remote_subscribers { + if now_ms - last_tick >= TIMEOUT_MS { + timeout_subscribes.push(*remote); + } + } + for subscriber in timeout_subscribes { + self.remote_subscribers.remove(&subscriber); + // if all subscribers are removed, we need to notify next hop to remove this node from subscriber list + if self.local_subscribers.is_empty() && self.remote_subscribers.is_empty() { + log::warn!("[SourceHint] Send Unsubscribe({}) to next node because remote subscriber {subscriber} timeout", self.session_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Unsubscribe(self.session_id))); + } + } + + let mut timeout_sources = vec![]; + for (source, last_tick) in &self.remote_sources { + if now_ms - last_tick >= TIMEOUT_MS { + timeout_sources.push(*source); + } + } + for source in timeout_sources { + self.remote_sources.remove(&source); + if !self.local_subscribers.is_empty() { + log::warn!("[SourceHint] Notify remove source({source}) to local {:?} actors because timeout", self.local_subscribers); + self.queue.push_back(Output::UnsubscribeSource(self.local_subscribers.clone(), source)); + } + } + + if !self.local_sources.is_empty() { + log::debug!("[SourceHint] ReSend Register({}) to root node", self.node_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Register { source: self.node_id, to_root: true })); + for (remote, _) in &self.remote_subscribers { + log::debug!("[SourceHint] ReSend Register({}) to subscribe {remote} node", self.node_id); + self.queue.push_back(Output::SendRemote(Some(*remote), SourceHint::Register { source: self.node_id, to_root: false })); + } + } + + if !self.local_subscribers.is_empty() || !self.remote_subscribers.is_empty() { + log::debug!("[SourceHint] ReSend Subscribe({}) to next node", self.session_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Subscribe(self.session_id))); + } + } + + pub fn on_local(&mut self, _now_ms: u64, actor: FeatureControlActor, cmd: LocalCmd) { + match cmd { + LocalCmd::Register => { + if !self.local_sources.contains(&actor) { + log::info!("[SourceHint] Register new local source: {:?}", actor); + self.local_sources.push(actor); + if self.local_sources.len() == 1 && self.remote_sources.is_empty() { + log::info!("[SourceHint] Send Register({}) to root node", self.node_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Register { source: self.node_id, to_root: true })); + } + + if self.local_sources.len() == 1 { + for (remote, _) in &self.remote_subscribers { + log::info!("[SourceHint] Notify Register({}) to subscribe {remote} node", self.node_id); + self.queue.push_back(Output::SendRemote(Some(*remote), SourceHint::Register { source: self.node_id, to_root: false })); + } + if !self.local_subscribers.is_empty() { + log::info!("[SourceHint] Notify new source({}) to local {:?} actors", self.node_id, self.local_subscribers); + self.queue.push_back(Output::SubscribeSource(self.local_subscribers.clone(), self.node_id)); + } + } + } + } + LocalCmd::Unregister => { + if let Some(index) = self.local_sources.iter().position(|x| x == &actor) { + log::info!("[SourceHint] Unregister local source: {:?}", actor); + self.local_sources.swap_remove(index); + if self.local_sources.is_empty() && self.remote_sources.is_empty() { + log::info!("[SourceHint] Send Unregister({}) to root node", self.node_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Unregister { source: self.node_id, to_root: true })); + } + + if self.local_sources.is_empty() { + for (remote, _) in &self.remote_subscribers { + log::info!("[SourceHint] Notify Unregister({}) to subscribe {remote} node", self.node_id); + self.queue.push_back(Output::SendRemote(Some(*remote), SourceHint::Unregister { source: self.node_id, to_root: false })); + } + if !self.local_subscribers.is_empty() { + log::info!("[SourceHint] Notify removed source({}) to local {:?} actors", self.node_id, self.local_subscribers); + self.queue.push_back(Output::UnsubscribeSource(self.local_subscribers.clone(), self.node_id)); + } + } + } + } + LocalCmd::Subscribe => { + if !self.local_subscribers.contains(&actor) { + log::info!("[SourceHint] Subscribe new local subscriber: {:?}", actor); + self.local_subscribers.push(actor); + if self.local_subscribers.len() == 1 && self.remote_subscribers.is_empty() { + log::info!("[SourceHint] Send Subscribe({}) to root node", self.session_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Subscribe(self.session_id))); + } + + for source in self.remote_sources.keys() { + log::info!("[SourceHint] Notify SubscribeSource for already added sources({source}) to actor {:?}", actor); + self.queue.push_back(Output::SubscribeSource(vec![actor], *source)); + } + if !self.local_sources.is_empty() { + self.queue.push_back(Output::SubscribeSource(vec![actor], self.node_id)); + } + } + } + LocalCmd::Unsubscribe => { + if let Some(index) = self.local_subscribers.iter().position(|x| x == &actor) { + log::info!("[SourceHint] Unsubscribe local subscriber: {:?}", actor); + self.local_subscribers.swap_remove(index); + // if all subscribers are removed, we need to notify next hop to remove this node from subscriber list + if self.local_subscribers.is_empty() && self.remote_subscribers.is_empty() { + log::info!("[SourceHint] Send Unsubscribe({}) to next node", self.session_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Unsubscribe(self.session_id))); + } + + // we unsubscriber for the local actor from all remote sources + for source in self.remote_sources.keys() { + log::info!("[SourceHint] Notify UnsubscribeSource for already added sources({source}) to actor {:?}", actor); + self.queue.push_back(Output::UnsubscribeSource(vec![actor], *source)); + } + // we unsubscriber for the local actor from local source + if !self.local_sources.is_empty() { + self.queue.push_back(Output::UnsubscribeSource(vec![actor], self.node_id)); + } + } + } + } + } + + pub fn on_remote(&mut self, now_ms: u64, remote: SocketAddr, cmd: SourceHint) { + match cmd { + SourceHint::Register { source, to_root } => { + // We only accept register from original source and go up to root node, or notify from next_hop (parent node1) + if !to_root && self.next_hop != Some(remote) { + log::warn!("[SourceHint] remote Register({source}) relayed from {remote} is not from next hop, ignore it"); + return; + } + + // Only to_root msg is climb up to root node + if to_root { + log::debug!("[SourceHint] forward remote Register({}) to root node", source); + self.queue.push_back(Output::SendRemote(None, SourceHint::Register { source, to_root: true })); + } + + for (subscriber, _) in &self.remote_subscribers { + if subscriber.eq(&remote) { + // for avoiding loop, we only send to other subscribers except sender + continue; + } + log::debug!("[SourceHint] forward remote Register({}) to subscribe {subscriber} node", source); + self.queue.push_back(Output::SendRemote(Some(*subscriber), SourceHint::Register { source, to_root: false })); + } + + // if source is new, notify to all local subscribers + if self.remote_sources.insert(source, now_ms).is_none() { + log::info!("[SourceHint] added remote source {source}"); + if !self.local_subscribers.is_empty() { + log::info!("[SourceHint] Notify new source({}) to local {:?} actors", source, self.local_subscribers); + self.queue.push_back(Output::SubscribeSource(self.local_subscribers.clone(), source)); + } + } + } + SourceHint::Unregister { source, to_root } => { + // We only accept unregister from original source and go up to root node, or notify from next_hop (parent node1 + if !to_root && self.next_hop != Some(remote) { + log::warn!("[SourceHint] Unregister({source}) relayed from {remote} is not from next hop, ignore it"); + return; + } + + // Only to_root msg is climb up to root node + if to_root { + log::debug!("[SourceHint] Send Register({}) to root node", self.node_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Unregister { source, to_root: true })); + } + + for (subscriber, _) in &self.remote_subscribers { + if subscriber.eq(&remote) { + // for avoiding loop, we only send to other subscribers except sender + continue; + } + log::debug!("[SourceHint] relay UnRegister({}) to subscribe {subscriber} node", source); + self.queue.push_back(Output::SendRemote(Some(*subscriber), SourceHint::Unregister { source, to_root: false })); + } + + // if source is deleted, notify to all local subscribers + if self.remote_sources.remove(&source).is_some() { + log::info!("[SourceHint] removed remote source {source}"); + if !self.local_subscribers.is_empty() { + log::info!("[SourceHint] Notify remove source({}) to local {:?} actors", self.node_id, self.local_subscribers); + self.queue.push_back(Output::UnsubscribeSource(self.local_subscribers.clone(), source)); + } + } + } + SourceHint::Subscribe(session) => { + if self.remote_subscribers.insert(remote, now_ms).is_none() { + log::info!("[SourceHint] added remote subscriber {remote}"); + self.queue.push_back(Output::SendRemote(Some(remote), SourceHint::SubscribeOk(session))); + if self.remote_subscribers.len() == 1 && self.local_subscribers.is_empty() { + log::info!("[SourceHint] Send Subscribe({}) to root node", self.session_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Subscribe(self.session_id))); + } + + let mut sources = self.remote_sources.keys().cloned().collect::>(); + if !self.local_sources.is_empty() { + sources.push(self.node_id); + } + + if !sources.is_empty() { + log::info!("[SourceHint] Notify SubscribeSource for already added sources({:?}) to remote {remote}", sources); + self.queue.push_back(Output::SendRemote(Some(remote), SourceHint::Sources(sources))); + } + } else { + //TODO check session for resend ACK + self.queue.push_back(Output::SendRemote(Some(remote), SourceHint::SubscribeOk(session))); + } + } + SourceHint::SubscribeOk(session) => { + if session == self.session_id { + self.next_hop = Some(remote); + } + } + SourceHint::Unsubscribe(session) => { + if self.remote_subscribers.remove(&remote).is_some() { + log::info!("[SourceHint] removed remote subscriber {remote}"); + self.queue.push_back(Output::SendRemote(Some(remote), SourceHint::UnsubscribeOk(session))); + // if all subscribers are removed, we need to notify next hop to remove this node from subscriber list + if self.local_subscribers.is_empty() && self.remote_subscribers.is_empty() { + log::info!("[SourceHint] Send Unsubscribe({}) to next node because all subscribers removed", self.session_id); + self.queue.push_back(Output::SendRemote(None, SourceHint::Unsubscribe(self.session_id))); + } + } else { + //check session for resend ACK + self.queue.push_back(Output::SendRemote(Some(remote), SourceHint::UnsubscribeOk(session))); + } + } + SourceHint::UnsubscribeOk(session) => { + if session == self.session_id { + self.next_hop = None; + } + } + SourceHint::Sources(sources) => { + for source in sources { + if self.remote_sources.insert(source, now_ms).is_none() { + log::info!("[SourceHint] added remote source {source}"); + for remote in self.remote_subscribers.keys() { + log::debug!("[SourceHint] Notify source({source}) from snapshot to remote {remote}"); + self.queue.push_back(Output::SendRemote(Some(*remote), SourceHint::Register { source, to_root: false })); + } + if !self.local_subscribers.is_empty() { + log::info!("[SourceHint] Notify new source({source}) to local {:?} actors", self.local_subscribers); + self.queue.push_back(Output::SubscribeSource(self.local_subscribers.clone(), source)); + } + } + } + } + } + } + + pub fn pop_output(&mut self) -> Option { + self.queue.pop_front() + } + + pub fn should_clear(&self) -> bool { + self.local_sources.is_empty() && self.remote_sources.is_empty() && self.local_subscribers.is_empty() && self.remote_subscribers.is_empty() + } +} + +#[cfg(test)] +mod tests { + use std::{net::SocketAddr, vec}; + + use crate::{base::FeatureControlActor, features::pubsub::controller::source_hint::TIMEOUT_MS}; + + use super::{LocalCmd, Output, SourceHint, SourceHintLogic}; + + #[test] + fn local_subscribe_should_send_event() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + //subscribe should send a subscribe message + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + assert_eq!(sh.pop_output(), None); + + sh.on_local(0, FeatureControlActor::Worker(1), LocalCmd::Subscribe); + assert_eq!(sh.pop_output(), None); + + //fake a local source should send local source event + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Register); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Register { source: node_id, to_root: true }))); + assert_eq!( + sh.pop_output(), + Some(Output::SubscribeSource(vec![FeatureControlActor::Controller, FeatureControlActor::Worker(1)], node_id)) + ); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn local_subscribe_should_handle_sources() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + //subscribe should send a subscribe message + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + + sh.on_remote(100, remote, SourceHint::Sources(vec![2, 3])); + assert_eq!(sh.pop_output(), Some(Output::SubscribeSource(vec![FeatureControlActor::Controller], 2))); + assert_eq!(sh.pop_output(), Some(Output::SubscribeSource(vec![FeatureControlActor::Controller], 3))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn local_register_should_send_event() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Register); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Register { source: node_id, to_root: true }))); + assert_eq!(sh.pop_output(), None); + + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Register); + assert_eq!(sh.pop_output(), None); + + sh.on_local(0, FeatureControlActor::Worker(1), LocalCmd::Register); + assert_eq!(sh.pop_output(), None); + + //subscribe should send a subscribe message and local source event + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SubscribeSource(vec![FeatureControlActor::Controller], node_id))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn remote_subscribe_should_send_event() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote_session_id = 4321; + + //fake a local source + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Register); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Register { source: node_id, to_root: true }))); + + //remote subscribe should receive a subscribe ok and snapshot of sources + sh.on_remote(0, remote, SourceHint::Subscribe(remote_session_id)); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote), SourceHint::SubscribeOk(remote_session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote), SourceHint::Sources(vec![node_id])))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn remote_register_should_send_event() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote_node_id = 2; + + //remote register should send event + sh.on_remote( + 0, + remote, + SourceHint::Register { + source: remote_node_id, + to_root: true, + }, + ); + + assert_eq!( + sh.pop_output(), + Some(Output::SendRemote( + None, + SourceHint::Register { + source: remote_node_id, + to_root: true + } + )) + ); + assert_eq!(sh.pop_output(), None); + + //subscribe should send a subscribe message and local source event + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SubscribeSource(vec![FeatureControlActor::Controller], remote_node_id))); + assert_eq!(sh.pop_output(), None); + + //unsubscribe should send a unsubscribe message and local source event + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Unsubscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Unsubscribe(session_id)))); + assert_eq!(sh.pop_output(), Some(Output::UnsubscribeSource(vec![FeatureControlActor::Controller], remote_node_id))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn remote_notify_register_should_not_climb_to_root() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote_node_id = 2; + + //remote register should send event + sh.on_remote( + 0, + remote, + SourceHint::Register { + source: remote_node_id, + to_root: false, + }, + ); + assert_eq!(sh.pop_output(), None); + + sh.on_remote( + 0, + remote, + SourceHint::Unregister { + source: remote_node_id, + to_root: false, + }, + ); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn remote_register_should_not_resend_same_sender() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + let remote1 = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote1_node_id = 2; + let remote1_session_id = 4321; + + let remote2 = SocketAddr::new([127, 0, 0, 2].into(), 1234); + let remote2_session_id = 4322; + + //remote subscribe should receive a subscribe ok and snapshot of sources + sh.on_remote(0, remote1, SourceHint::Subscribe(remote1_session_id)); + sh.on_remote(0, remote2, SourceHint::Subscribe(remote2_session_id)); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote1), SourceHint::SubscribeOk(remote1_session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote2), SourceHint::SubscribeOk(remote2_session_id)))); + assert_eq!(sh.pop_output(), None); + + //remote register should send event except sender + sh.on_remote( + 0, + remote1, + SourceHint::Register { + source: remote1_node_id, + to_root: true, + }, + ); + + assert_eq!( + sh.pop_output(), + Some(Output::SendRemote( + None, + SourceHint::Register { + source: remote1_node_id, + to_root: true + } + )) + ); + assert_eq!( + sh.pop_output(), + Some(Output::SendRemote( + Some(remote2), + SourceHint::Register { + source: remote1_node_id, + to_root: false + } + )) + ); + assert_eq!(sh.pop_output(), None); + + //remote register should send event except sender + sh.on_remote( + 0, + remote1, + SourceHint::Unregister { + source: remote1_node_id, + to_root: true, + }, + ); + + assert_eq!( + sh.pop_output(), + Some(Output::SendRemote( + None, + SourceHint::Unregister { + source: remote1_node_id, + to_root: true + } + )) + ); + assert_eq!( + sh.pop_output(), + Some(Output::SendRemote( + Some(remote2), + SourceHint::Unregister { + source: remote1_node_id, + to_root: false + } + )) + ); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn register_resend_after_tick() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote_node_id = 2; + let remote_session_id = 4321; + + //fake a local source with a remote subscribe + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Register); + sh.on_remote(0, remote, SourceHint::Subscribe(remote_session_id)); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Register { source: node_id, to_root: true }))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote), SourceHint::SubscribeOk(remote_session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote), SourceHint::Sources(vec![node_id])))); + assert_eq!(sh.pop_output(), None); + + sh.on_tick(1000); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Register { source: node_id, to_root: true }))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote), SourceHint::Register { source: node_id, to_root: false }))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn subscribe_resend_after_tick() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + //subscribe should send a subscribe message + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + sh.on_tick(1000); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn remote_source_timeout() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + //subscribe should send a subscribe message and local source event + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote_node_id = 2; + + sh.on_remote( + 0, + remote, + SourceHint::Register { + source: remote_node_id, + to_root: true, + }, + ); + + assert_eq!( + sh.pop_output(), + Some(Output::SendRemote( + None, + SourceHint::Register { + source: remote_node_id, + to_root: true + } + )) + ); + assert_eq!(sh.pop_output(), Some(Output::SubscribeSource(vec![FeatureControlActor::Controller], remote_node_id))); + assert_eq!(sh.pop_output(), None); + + assert_eq!(sh.remote_sources.len(), 1); + sh.on_tick(1000); + assert_eq!(sh.remote_sources.len(), 1); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + sh.on_tick(TIMEOUT_MS); + assert_eq!(sh.remote_sources.len(), 0); + assert_eq!(sh.pop_output(), Some(Output::UnsubscribeSource(vec![FeatureControlActor::Controller], remote_node_id))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn remote_source_notify_timeout() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + //subscribe should send a subscribe message and local source event + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote_node_id = 2; + + sh.on_remote(0, remote, SourceHint::SubscribeOk(session_id)); + sh.on_remote( + 0, + remote, + SourceHint::Register { + source: remote_node_id, + to_root: false, + }, + ); + + assert_eq!(sh.pop_output(), Some(Output::SubscribeSource(vec![FeatureControlActor::Controller], remote_node_id))); + assert_eq!(sh.pop_output(), None); + + assert_eq!(sh.remote_sources.len(), 1); + sh.on_tick(1000); + assert_eq!(sh.remote_sources.len(), 1); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + sh.on_tick(TIMEOUT_MS); + assert_eq!(sh.remote_sources.len(), 0); + assert_eq!(sh.pop_output(), Some(Output::UnsubscribeSource(vec![FeatureControlActor::Controller], remote_node_id))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn remote_subscriber_timeout() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + let remote = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote_session_id = 4321; + + sh.on_remote(0, remote, SourceHint::Subscribe(remote_session_id)); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(Some(remote), SourceHint::SubscribeOk(remote_session_id)))); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + assert_eq!(sh.remote_subscribers.len(), 1); + sh.on_tick(1000); + assert_eq!(sh.remote_subscribers.len(), 1); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + sh.on_tick(TIMEOUT_MS); + assert_eq!(sh.remote_subscribers.len(), 0); + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Unsubscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + } + + #[test] + fn next_hop_changed() { + let node_id = 1; + let session_id = 1234; + let mut sh = SourceHintLogic::new(node_id, session_id); + + //subscribe should send a subscribe message and local source event + sh.on_local(0, FeatureControlActor::Controller, LocalCmd::Subscribe); + + assert_eq!(sh.pop_output(), Some(Output::SendRemote(None, SourceHint::Subscribe(session_id)))); + assert_eq!(sh.pop_output(), None); + + let remote1 = SocketAddr::new([127, 0, 0, 1].into(), 1234); + let remote2 = SocketAddr::new([127, 0, 0, 2].into(), 1234); + + sh.on_remote(0, remote1, SourceHint::SubscribeOk(session_id)); + + let source_id = 100; + sh.on_remote(0, remote1, SourceHint::Register { source: source_id, to_root: false }); + + assert_eq!(sh.pop_output(), Some(Output::SubscribeSource(vec![FeatureControlActor::Controller], source_id))); + assert_eq!(sh.pop_output(), None); + + //now next hop changed to remote2 + sh.on_remote(1000, remote2, SourceHint::SubscribeOk(session_id)); + + //then we will reject any message from remote1 + sh.on_remote(0, remote1, SourceHint::Unregister { source: source_id, to_root: false }); + assert_eq!(sh.pop_output(), None); + + //we only accept from remote2 + sh.on_remote(0, remote2, SourceHint::Unregister { source: source_id, to_root: false }); + assert_eq!(sh.pop_output(), Some(Output::UnsubscribeSource(vec![FeatureControlActor::Controller], source_id))); + assert_eq!(sh.pop_output(), None); + } +} diff --git a/packages/network/src/features/pubsub/mod.rs b/packages/network/src/features/pubsub/mod.rs index 9af145a9..84720647 100644 --- a/packages/network/src/features/pubsub/mod.rs +++ b/packages/network/src/features/pubsub/mod.rs @@ -4,14 +4,14 @@ use atm0s_sdn_identity::NodeId; use crate::base::FeatureControlActor; -use self::msg::{RelayControl, RelayId}; +use self::msg::{RelayControl, RelayId, SourceHint}; mod controller; mod msg; mod worker; pub use controller::PubSubFeature; -pub use msg::ChannelId; +pub use msg::{ChannelId, Feedback}; pub use worker::PubSubFeatureWorker; pub const FEATURE_ID: u8 = 5; @@ -19,9 +19,14 @@ pub const FEATURE_NAME: &str = "pubsub"; #[derive(Debug, Clone, PartialEq, Eq)] pub enum ChannelControl { + SubAuto, + FeedbackAuto(Feedback), + UnsubAuto, SubSource(NodeId), UnsubSource(NodeId), + PubStart, PubData(Vec), + PubStop, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -31,6 +36,7 @@ pub struct Control(pub ChannelId, pub ChannelControl); pub enum ChannelEvent { RouteChanged(NodeId), SourceData(NodeId, Vec), + FeedbackData(Feedback), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -43,6 +49,7 @@ pub enum RelayWorkerControl { SendSubOk(u64, SocketAddr), SendUnsubOk(u64, SocketAddr), SendRouteChanged, + SendFeedback(Feedback, SocketAddr), RouteSetSource(SocketAddr), RouteDelSource(SocketAddr), RouteSetLocal(FeatureControlActor), @@ -66,11 +73,13 @@ impl RelayWorkerControl { #[derive(Debug, Clone)] pub enum ToWorker { - RelayWorkerControl(RelayId, RelayWorkerControl), + RelayControl(RelayId, RelayWorkerControl), + SourceHint(ChannelId, Option, SourceHint), RelayData(RelayId, Vec), } #[derive(Debug, Clone)] pub enum ToController { - RemoteControl(SocketAddr, RelayId, RelayControl), + RelayControl(SocketAddr, RelayId, RelayControl), + SourceHint(SocketAddr, ChannelId, SourceHint), } diff --git a/packages/network/src/features/pubsub/msg.rs b/packages/network/src/features/pubsub/msg.rs index d904e433..42d3b853 100644 --- a/packages/network/src/features/pubsub/msg.rs +++ b/packages/network/src/features/pubsub/msg.rs @@ -12,6 +12,48 @@ simple_pub_type!(ChannelId, u64); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct RelayId(pub ChannelId, pub NodeId); +#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct Feedback { + pub kind: u8, + pub count: u64, + pub max: u64, + pub min: u64, + pub sum: u64, + pub interval_ms: u16, + pub timeout_ms: u16, +} + +impl Feedback { + pub fn simple(kind: u8, value: u64, interval_ms: u16, timeout_ms: u16) -> Self { + Feedback { + kind, + count: 1, + max: value, + min: value, + sum: value, + interval_ms, + timeout_ms, + } + } +} + +///implement add to Feedback +impl std::ops::Add for Feedback { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Feedback { + kind: self.kind, + count: self.count + rhs.count, + max: self.max.max(rhs.max), + min: self.min.min(rhs.min), + sum: self.sum + rhs.sum, + interval_ms: self.interval_ms.min(rhs.interval_ms), + timeout_ms: self.timeout_ms.max(rhs.timeout_ms), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum RelayControl { Sub(u64), @@ -19,6 +61,7 @@ pub enum RelayControl { SubOK(u64), UnsubOK(u64), RouteChanged(u64), + Feedback(Feedback), } impl RelayControl { @@ -30,6 +73,38 @@ impl RelayControl { } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum SourceHint { + /// This is used to notify a source is new or still alive. + /// This message is send to next hop and relayed to all subscribers except sender. + Register { + source: NodeId, + to_root: bool, + }, + /// This is used to notify a source is ended. + /// This message is send to next hop and relayed to all subscribers except sender. + Unregister { + source: NodeId, + to_root: bool, + }, + Subscribe(u64), + SubscribeOk(u64), + Unsubscribe(u64), + UnsubscribeOk(u64), + /// This is used when a new subscriber is added, it is like a snapshot for faster initing state. + Sources(Vec), +} + +impl SourceHint { + pub fn should_create(&self) -> bool { + match self { + SourceHint::Register { .. } => true, + SourceHint::Subscribe(_) => true, + _ => false, + } + } +} + pub enum PubsubMessageError { TransportError(TransportMsgHeaderError), DeserializeError, @@ -38,6 +113,7 @@ pub enum PubsubMessageError { #[derive(Debug, Serialize, Deserialize)] pub enum PubsubMessage { Control(RelayId, RelayControl), + SourceHint(ChannelId, SourceHint), Data(RelayId, Vec), } diff --git a/packages/network/src/features/pubsub/worker.rs b/packages/network/src/features/pubsub/worker.rs index f57f45ee..bd8a7f61 100644 --- a/packages/network/src/features/pubsub/worker.rs +++ b/packages/network/src/features/pubsub/worker.rs @@ -4,9 +4,9 @@ use std::{ }; use atm0s_sdn_identity::ConnId; -use atm0s_sdn_router::RouterTable; +use atm0s_sdn_router::{RouteAction, RouterTable}; -use crate::base::{FeatureControlActor, FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, GenericBuffer, TransportMsgHeader}; +use crate::base::{Buffer, FeatureControlActor, FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, TransportMsgHeader}; use super::{ msg::{PubsubMessage, RelayControl, RelayId}, @@ -50,14 +50,18 @@ impl FeatureWorker for PubSubFeatureWork _conn: ConnId, remote: SocketAddr, _header: TransportMsgHeader, - buf: GenericBuffer<'a>, + buf: Buffer<'a>, ) -> Option> { log::debug!("[PubSubWorker] on_network_raw from {}", remote); let msg = PubsubMessage::try_from(&buf as &[u8]).ok()?; match msg { PubsubMessage::Control(relay_id, control) => { - log::debug!("[PubSubWorker] received PubsubMessage::Control({:?}, {:?})", relay_id, control); - Some(FeatureWorkerOutput::ToController(ToController::RemoteControl(remote, relay_id, control))) + log::debug!("[PubSubWorker] received PubsubMessage::RelayControl({:?}, {:?})", relay_id, control); + Some(FeatureWorkerOutput::ToController(ToController::RelayControl(remote, relay_id, control))) + } + PubsubMessage::SourceHint(channel, control) => { + log::debug!("[PubSubWorker] received PubsubMessage::SourceHint({:?}, {:?})", channel, control); + Some(FeatureWorkerOutput::ToController(ToController::SourceHint(remote, channel, control))) } PubsubMessage::Data(relay_id, data) => { log::debug!("[PubSubWorker] received PubsubMessage::Data({:?}, size {})", relay_id, data.len()); @@ -73,7 +77,7 @@ impl FeatureWorker for PubSubFeatureWork let control = PubsubMessage::Data(relay_id, data); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - Some(FeatureWorkerOutput::RawBroadcast2(relay.remotes.clone(), GenericBuffer::from(self.buf[0..size].to_vec()))) + Some(FeatureWorkerOutput::RawBroadcast2(relay.remotes.clone(), Buffer::from(self.buf[0..size].to_vec()))) } else { //TODO avoid push temp to queue self.queue.pop_front() @@ -88,7 +92,7 @@ impl FeatureWorker for PubSubFeatureWork fn on_input<'a>(&mut self, ctx: &mut FeatureWorkerContext, _now: u64, input: FeatureWorkerInput<'a, Control, ToWorker>) -> Option> { match input { - FeatureWorkerInput::FromController(_, ToWorker::RelayWorkerControl(relay_id, control)) => match control { + FeatureWorkerInput::FromController(_, ToWorker::RelayControl(relay_id, control)) => match control { RelayWorkerControl::SendSub(uuid, remote) => { let dest = if let Some(remote) = remote { log::debug!("[PubsubWorker] SendSub for {:?} to {}", relay_id, remote); @@ -103,28 +107,35 @@ impl FeatureWorker for PubSubFeatureWork let control = PubsubMessage::Control(relay_id, RelayControl::Sub(uuid)); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - Some(FeatureWorkerOutput::RawDirect2(dest, GenericBuffer::from(self.buf[0..size].to_vec()))) + Some(FeatureWorkerOutput::RawDirect2(dest, Buffer::from(self.buf[0..size].to_vec()))) + } + RelayWorkerControl::SendFeedback(fb, remote) => { + log::debug!("[PubsubWorker] SendFeedback for {:?} to {:?}", relay_id, remote); + let control = PubsubMessage::Control(relay_id, RelayControl::Feedback(fb)); + let size: usize = control.write_to(&mut self.buf)?; + //TODO avoid copy + Some(FeatureWorkerOutput::RawDirect2(remote, Buffer::from(self.buf[0..size].to_vec()))) } RelayWorkerControl::SendUnsub(uuid, remote) => { log::debug!("[PubsubWorker] SendUnsub for {:?} to {:?}", relay_id, remote); let control = PubsubMessage::Control(relay_id, RelayControl::Unsub(uuid)); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - Some(FeatureWorkerOutput::RawDirect2(remote, GenericBuffer::from(self.buf[0..size].to_vec()))) + Some(FeatureWorkerOutput::RawDirect2(remote, Buffer::from(self.buf[0..size].to_vec()))) } RelayWorkerControl::SendSubOk(uuid, remote) => { log::debug!("[PubsubWorker] SendSubOk for {:?} to {:?}", relay_id, remote); let control = PubsubMessage::Control(relay_id, RelayControl::SubOK(uuid)); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - Some(FeatureWorkerOutput::RawDirect2(remote, GenericBuffer::from(self.buf[0..size].to_vec()))) + Some(FeatureWorkerOutput::RawDirect2(remote, Buffer::from(self.buf[0..size].to_vec()))) } RelayWorkerControl::SendUnsubOk(uuid, remote) => { log::debug!("[PubsubWorker] SendUnsubOk for {:?} to {:?}", relay_id, remote); let control = PubsubMessage::Control(relay_id, RelayControl::UnsubOK(uuid)); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - Some(FeatureWorkerOutput::RawDirect2(remote, GenericBuffer::from(self.buf[0..size].to_vec()))) + Some(FeatureWorkerOutput::RawDirect2(remote, Buffer::from(self.buf[0..size].to_vec()))) } RelayWorkerControl::SendRouteChanged => { let relay = self.relays.get(&relay_id)?; @@ -133,7 +144,7 @@ impl FeatureWorker for PubSubFeatureWork let control = PubsubMessage::Control(relay_id, RelayControl::RouteChanged(*uuid)); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - self.queue.push_back(FeatureWorkerOutput::RawDirect2(*addr, GenericBuffer::from(self.buf[0..size].to_vec()))); + self.queue.push_back(FeatureWorkerOutput::RawDirect2(*addr, Buffer::from(self.buf[0..size].to_vec()))); } self.queue.pop_front() } @@ -216,6 +227,23 @@ impl FeatureWorker for PubSubFeatureWork None } }, + FeatureWorkerInput::FromController(_, ToWorker::SourceHint(channel, remote, data)) => { + if let Some(remote) = remote { + let control = PubsubMessage::SourceHint(channel, data); + let size = control.write_to(&mut self.buf)?; + Some(FeatureWorkerOutput::RawDirect2(remote, Buffer::from(self.buf[0..size].to_vec()))) + } else { + let next = ctx.router.path_to_key(*channel as u32); + match next { + RouteAction::Next(remote) => { + let control = PubsubMessage::SourceHint(channel, data); + let size = control.write_to(&mut self.buf)?; + Some(FeatureWorkerOutput::RawDirect2(remote, Buffer::from(self.buf[0..size].to_vec()))) + } + _ => None, + } + } + } FeatureWorkerInput::FromController(_, ToWorker::RelayData(relay_id, data)) => { let relay = self.relays.get(&relay_id)?; if relay.remotes.is_empty() { @@ -225,7 +253,7 @@ impl FeatureWorker for PubSubFeatureWork let control = PubsubMessage::Data(relay_id, data); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - Some(FeatureWorkerOutput::RawBroadcast2(relay.remotes.clone(), GenericBuffer::from(self.buf[0..size].to_vec()))) + Some(FeatureWorkerOutput::RawBroadcast2(relay.remotes.clone(), Buffer::from(self.buf[0..size].to_vec()))) } FeatureWorkerInput::Control(actor, control) => match control { Control(channel, ChannelControl::PubData(data)) => { @@ -241,7 +269,7 @@ impl FeatureWorker for PubSubFeatureWork let control = PubsubMessage::Data(relay_id, data); let size: usize = control.write_to(&mut self.buf)?; //TODO avoid copy - Some(FeatureWorkerOutput::RawBroadcast2(relay.remotes.clone(), GenericBuffer::from(self.buf[0..size].to_vec()))) + Some(FeatureWorkerOutput::RawBroadcast2(relay.remotes.clone(), Buffer::from(self.buf[0..size].to_vec()))) } else { //TODO avoid push temp to queue self.queue.pop_front() diff --git a/packages/network/src/features/router_sync.rs b/packages/network/src/features/router_sync.rs index 4da97a19..1d07c4c4 100644 --- a/packages/network/src/features/router_sync.rs +++ b/packages/network/src/features/router_sync.rs @@ -166,6 +166,7 @@ impl FeatureWorker for RouterSyncFeature log::warn!("No handler for local message in {}", FEATURE_NAME); None } + #[cfg(feature = "vpn")] FeatureWorkerInput::TunPkt(_buf) => { log::warn!("No handler for tun packet in {}", FEATURE_NAME); None diff --git a/packages/network/src/features/socket.rs b/packages/network/src/features/socket.rs index 788b560d..ec1701b9 100644 --- a/packages/network/src/features/socket.rs +++ b/packages/network/src/features/socket.rs @@ -4,7 +4,7 @@ use atm0s_sdn_identity::NodeId; use atm0s_sdn_router::RouteRule; use crate::base::{ - Feature, FeatureContext, FeatureControlActor, FeatureInput, FeatureOutput, FeatureSharedInput, FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, GenericBuffer, + Buffer, Feature, FeatureContext, FeatureControlActor, FeatureInput, FeatureOutput, FeatureSharedInput, FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, NetOutgoingMeta, TransportMsgHeader, Ttl, }; @@ -228,7 +228,7 @@ impl FeatureWorker for SocketFeatureWork _conn: atm0s_sdn_identity::ConnId, _remote: std::net::SocketAddr, header: TransportMsgHeader, - buf: GenericBuffer<'a>, + buf: Buffer<'a>, ) -> Option> { self.process_incoming(header.from_node?, &(&buf)[header.serialize_size()..], header.meta) } diff --git a/packages/network/src/features/vpn.rs b/packages/network/src/features/vpn.rs index 05c2d387..07841def 100644 --- a/packages/network/src/features/vpn.rs +++ b/packages/network/src/features/vpn.rs @@ -1,7 +1,11 @@ +#[cfg(feature = "vpn")] +use crate::base::{BufferMut, TransportMsg}; +#[cfg(feature = "vpn")] use atm0s_sdn_identity::{NodeId, NodeIdType}; +#[cfg(feature = "vpn")] use atm0s_sdn_router::{RouteAction, RouteRule, RouterTable}; -use crate::base::{Feature, FeatureContext, FeatureInput, FeatureOutput, FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput, GenericBuffer, GenericBufferMut, TransportMsg}; +use crate::base::{Buffer, Feature, FeatureContext, FeatureInput, FeatureOutput, FeatureWorker, FeatureWorkerContext, FeatureWorkerInput, FeatureWorkerOutput}; pub const FEATURE_ID: u8 = 3; pub const FEATURE_NAME: &str = "vpn"; @@ -34,7 +38,8 @@ impl Feature for VpnFeature { pub struct VpnFeatureWorker; impl VpnFeatureWorker { - fn process_tun<'a>(&mut self, ctx: &FeatureWorkerContext, mut pkt: GenericBufferMut<'a>) -> Option> { + #[cfg(feature = "vpn")] + fn process_tun<'a>(&mut self, ctx: &FeatureWorkerContext, mut pkt: BufferMut<'a>) -> Option> { #[cfg(any(target_os = "macos", target_os = "ios"))] let to_ip = &pkt[20..24]; #[cfg(any(target_os = "linux", target_os = "android"))] @@ -56,14 +61,20 @@ impl VpnFeatureWorker { } } - fn process_udp<'a>(&self, _ctx: &FeatureWorkerContext, pkt: GenericBuffer<'a>) -> Option> { - Some(FeatureWorkerOutput::TunPkt(pkt)) + fn process_udp<'a>(&self, _ctx: &FeatureWorkerContext, pkt: Buffer<'a>) -> Option> { + #[cfg(feature = "vpn")] + { + Some(FeatureWorkerOutput::TunPkt(pkt)) + } + #[cfg(not(feature = "vpn"))] + None } } impl FeatureWorker for VpnFeatureWorker { fn on_input<'a>(&mut self, ctx: &mut FeatureWorkerContext, _now: u64, input: FeatureWorkerInput<'a, Control, ToWorker>) -> Option> { match input { + #[cfg(feature = "vpn")] FeatureWorkerInput::TunPkt(pkt) => self.process_tun(ctx, pkt), FeatureWorkerInput::Network(_conn, _header, pkt) => self.process_udp(ctx, pkt), _ => None, diff --git a/packages/network/src/lib.rs b/packages/network/src/lib.rs index a9029192..6d7b782d 100644 --- a/packages/network/src/lib.rs +++ b/packages/network/src/lib.rs @@ -2,7 +2,7 @@ use std::net::SocketAddr; use atm0s_sdn_identity::{ConnId, NodeAddr, NodeId}; use atm0s_sdn_router::RouteRule; -use base::{FeatureControlActor, NeighboursControl, NetIncomingMeta, NetOutgoingMeta, SecureContext, ServiceId}; +use base::{FeatureControlActor, NeighboursControl, NetIncomingMeta, NetOutgoingMeta, SecureContext, ServiceControlActor, ServiceId}; pub use convert_enum; use features::{Features, FeaturesControl, FeaturesEvent, FeaturesToController, FeaturesToWorker}; @@ -14,6 +14,7 @@ pub mod data_plane; pub mod features; pub mod secure; pub mod services; +pub mod worker; #[derive(Debug, Clone, convert_enum::From)] pub enum ExtIn { @@ -27,22 +28,25 @@ pub enum ExtIn { #[derive(Debug, Clone, PartialEq, Eq)] pub enum ExtOut { FeaturesEvent(FeaturesEvent), - ServicesEvent(ServicesEvent), + ServicesEvent(ServiceId, ServicesEvent), } #[derive(Debug, Clone)] -pub enum LogicControl { +pub enum LogicControl { Feature(FeaturesToController), Service(ServiceId, TC), NetNeighbour(SocketAddr, NeighboursControl), NetRemote(Features, ConnId, NetIncomingMeta, Vec), NetLocal(Features, NetIncomingMeta, Vec), FeaturesControl(FeatureControlActor, FeaturesControl), + ServicesControl(ServiceControlActor, ServiceId, SC), ServiceEvent(ServiceId, FeaturesEvent), + ExtFeaturesEvent(FeaturesEvent), + ExtServicesEvent(ServiceId, SE), } #[derive(Debug, Clone)] -pub enum LogicEvent { +pub enum LogicEvent { NetNeighbour(SocketAddr, NeighboursControl), NetDirect(Features, SocketAddr, ConnId, NetOutgoingMeta, Vec), NetRoute(Features, RouteRule, NetOutgoingMeta, Vec), @@ -52,16 +56,31 @@ pub enum LogicEvent { /// first bool is flag for broadcast or not Feature(bool, FeaturesToWorker), Service(ServiceId, TW), + /// first u16 is worker id + ExtFeaturesEvent(u16, FeaturesEvent), + /// first u16 is worker id + ExtServicesEvent(u16, ServiceId, SE), } -impl LogicEvent { - pub fn is_broadcast(&self) -> bool { +pub enum LogicEventDest { + Broadcast, + Any, + Worker(u16), +} + +impl LogicEvent { + pub fn dest(&self) -> LogicEventDest { match self { - LogicEvent::Pin(..) => true, - LogicEvent::UnPin(..) => true, - LogicEvent::Feature(is_broadcast, ..) => *is_broadcast, - LogicEvent::Service(..) => true, - _ => false, + LogicEvent::Pin(..) => LogicEventDest::Broadcast, + LogicEvent::UnPin(..) => LogicEventDest::Broadcast, + LogicEvent::Service(..) => LogicEventDest::Broadcast, + LogicEvent::Feature(true, ..) => LogicEventDest::Broadcast, + LogicEvent::Feature(false, ..) => LogicEventDest::Any, + LogicEvent::NetNeighbour(_, _) => LogicEventDest::Any, + LogicEvent::NetDirect(_, _, _, _, _) => LogicEventDest::Any, + LogicEvent::NetRoute(_, _, _, _) => LogicEventDest::Any, + LogicEvent::ExtFeaturesEvent(worker, _) => LogicEventDest::Worker(*worker), + LogicEvent::ExtServicesEvent(worker, _, _) => LogicEventDest::Worker(*worker), } } } diff --git a/packages/network/src/secure/encryption/x25519_dalek_aes.rs b/packages/network/src/secure/encryption/x25519_dalek_aes.rs index e02dafb0..925fa43c 100644 --- a/packages/network/src/secure/encryption/x25519_dalek_aes.rs +++ b/packages/network/src/secure/encryption/x25519_dalek_aes.rs @@ -10,7 +10,7 @@ use aes_gcm::{ use rand::rngs::OsRng; use x25519_dalek::{EphemeralSecret, PublicKey}; -use crate::base::{DecryptionError, Decryptor, EncryptionError, Encryptor, GenericBufferMut, HandshakeBuilder, HandshakeError, HandshakeRequester, HandshakeResponder}; +use crate::base::{BufferMut, DecryptionError, Decryptor, EncryptionError, Encryptor, HandshakeBuilder, HandshakeError, HandshakeRequester, HandshakeResponder}; const MSG_TIMEOUT_MS: u64 = 5000; // after 5 seconds message is considered expired @@ -93,11 +93,11 @@ impl EncryptorXDA { } impl Encryptor for EncryptorXDA { - fn encrypt<'a>(&mut self, now_ms: u64, buf: &mut GenericBufferMut<'a>) -> Result<(), EncryptionError> { + fn encrypt<'a>(&mut self, now_ms: u64, buf: &mut BufferMut<'a>) -> Result<(), EncryptionError> { let mut nonce = Aes256Gcm::generate_nonce(&mut OsRng); nonce[4..].copy_from_slice(&now_ms.to_be_bytes()); - self.aes.encrypt_in_place(&nonce, &[], buf).map_err(|_| EncryptionError::EncryptFailed)?; - buf.extend_from_slice(&nonce).map_err(|_| EncryptionError::EncryptFailed)?; + self.aes.encrypt_in_place(&nonce, &[], &mut BufferMut2(buf)).map_err(|_| EncryptionError::EncryptFailed)?; + buf.push_back(&nonce); Ok(()) } @@ -131,7 +131,7 @@ impl Debug for DecryptorXDA { } impl Decryptor for DecryptorXDA { - fn decrypt<'a>(&mut self, now_ms: u64, data: &mut GenericBufferMut<'a>) -> Result<(), DecryptionError> { + fn decrypt<'a>(&mut self, now_ms: u64, data: &mut BufferMut<'a>) -> Result<(), DecryptionError> { let nonce = if let Some(nonce) = data.pop_back(12) { nonce.to_vec() } else { @@ -142,7 +142,7 @@ impl Decryptor for DecryptorXDA { return Err(DecryptionError::TooOld); } let nonce = Nonce::from_slice(&nonce); - self.aes.decrypt_in_place(nonce, &[], data).map_err(|_| DecryptionError::DecryptError)?; + self.aes.decrypt_in_place(nonce, &[], &mut BufferMut2(data)).map_err(|_| DecryptionError::DecryptError)?; Ok(()) } @@ -154,34 +154,36 @@ impl Decryptor for DecryptorXDA { } } -impl<'a> Buffer for GenericBufferMut<'a> { +struct BufferMut2<'a, 'b>(&'a mut BufferMut<'b>); + +impl<'a, 'b> Buffer for BufferMut2<'a, 'b> { fn extend_from_slice(&mut self, other: &[u8]) -> aes_gcm::aead::Result<()> { - self.push_back(other); + self.0.push_back(other); Ok(()) } fn truncate(&mut self, len: usize) { - GenericBufferMut::truncate(self, len).expect("Should truncate ok"); + BufferMut::truncate(self.0, len).expect("Should truncate ok"); } fn len(&self) -> usize { - self.deref().len() + self.0.deref().len() } fn is_empty(&self) -> bool { - self.deref().is_empty() + self.0.deref().is_empty() } } -impl<'a> AsRef<[u8]> for GenericBufferMut<'a> { +impl<'a, 'b> AsRef<[u8]> for BufferMut2<'a, 'b> { fn as_ref(&self) -> &[u8] { - self.deref() + self.0.deref() } } -impl<'a> AsMut<[u8]> for GenericBufferMut<'a> { +impl<'a, 'b> AsMut<[u8]> for BufferMut2<'a, 'b> { fn as_mut(&mut self) -> &mut [u8] { - self.deref_mut() + self.0.deref_mut() } } @@ -189,7 +191,7 @@ impl<'a> AsMut<[u8]> for GenericBufferMut<'a> { mod tests { use std::ops::Deref; - use crate::base::{GenericBufferMut, HandshakeRequester, HandshakeResponder}; + use crate::base::{BufferMut, HandshakeRequester, HandshakeResponder}; use super::{HandshakeRequesterXDA, HandshakeResponderXDA}; @@ -203,13 +205,13 @@ mod tests { let msg = [1, 2, 3, 4]; - let mut buf1 = GenericBufferMut::build(&msg, 0, 1000); + let mut buf1 = BufferMut::build(&msg, 0, 1000); s_encrypt.encrypt(123, &mut buf1).expect("Should ok"); assert_ne!(buf1.len(), msg.len()); c_decrypt.decrypt(124, &mut buf1).expect("Should ok"); assert_eq!(buf1.deref(), msg); - let mut buf2 = GenericBufferMut::build(&msg, 0, 1000); + let mut buf2 = BufferMut::build(&msg, 0, 1000); c_encrypt.encrypt(123, &mut buf2).expect("Should ok"); assert_ne!(buf2.len(), msg.len()); s_decrypt.decrypt(124, &mut buf2).expect("Should ok"); @@ -224,11 +226,11 @@ mod tests { let (mut s_encrypt, _s_decrypt, res) = server.process_public_request(client.create_public_request().expect("").as_slice()).expect("Should ok"); let (_c_encrypt, mut c_decrypt) = client.process_public_response(res.as_slice()).expect("Should ok"); - let mut buf1 = GenericBufferMut::build(&[0, 0, 0, 1], 0, 1000); + let mut buf1 = BufferMut::build(&[0, 0, 0, 1], 0, 1000); s_encrypt.encrypt(123, &mut buf1).expect("Should ok"); - let mut buf2 = GenericBufferMut::build(&[0, 0, 0, 2], 0, 1000); + let mut buf2 = BufferMut::build(&[0, 0, 0, 2], 0, 1000); s_encrypt.encrypt(123, &mut buf2).expect("Should ok"); - let mut buf3 = GenericBufferMut::build(&[0, 0, 0, 3], 0, 1000); + let mut buf3 = BufferMut::build(&[0, 0, 0, 3], 0, 1000); s_encrypt.encrypt(123, &mut buf3).expect("Should ok"); c_decrypt.decrypt(123, &mut buf1).expect("Should ok"); @@ -265,7 +267,7 @@ mod tests { for i in 0..1024 { let value: u32 = i; let msg = value.to_be_bytes(); - let mut buf = GenericBufferMut::build(&msg, 0, 1000); + let mut buf = BufferMut::build(&msg, 0, 1000); s_enc_threads[i as usize % ENC_THREADS].encrypt(i as u64, &mut buf).expect("Should ok"); c_dec_threads[i as usize % DEC_THREADS].decrypt(i as u64, &mut buf).expect("Should ok"); assert_eq!(buf.deref(), msg); diff --git a/packages/network/src/services/manual_discovery.rs b/packages/network/src/services/manual_discovery.rs index 96366304..8230751d 100644 --- a/packages/network/src/services/manual_discovery.rs +++ b/packages/network/src/services/manual_discovery.rs @@ -159,7 +159,7 @@ impl Service ServiceWorker for ManualDiscoveryServiceWorker { +impl ServiceWorker for ManualDiscoveryServiceWorker { fn service_id(&self) -> u8 { SERVICE_ID } @@ -206,7 +206,7 @@ where Box::new(ManualDiscoveryService::new(self.node_addr.clone(), self.local_tags.clone(), self.connect_tags.clone())) } - fn create_worker(&self) -> Box> { + fn create_worker(&self) -> Box> { Box::new(ManualDiscoveryServiceWorker {}) } } diff --git a/packages/network/src/services/visualization.rs b/packages/network/src/services/visualization.rs index 55480907..bc9f4983 100644 --- a/packages/network/src/services/visualization.rs +++ b/packages/network/src/services/visualization.rs @@ -207,7 +207,7 @@ where pub struct VisualizationServiceWorker {} -impl ServiceWorker for VisualizationServiceWorker { +impl ServiceWorker for VisualizationServiceWorker { fn service_id(&self) -> u8 { SERVICE_ID } @@ -255,7 +255,7 @@ where Box::new(VisualizationService::new()) } - fn create_worker(&self) -> Box> { + fn create_worker(&self) -> Box> { Box::new(VisualizationServiceWorker {}) } } diff --git a/packages/network/src/worker.rs b/packages/network/src/worker.rs new file mode 100644 index 00000000..c533c03a --- /dev/null +++ b/packages/network/src/worker.rs @@ -0,0 +1,217 @@ +use std::fmt::Debug; + +use atm0s_sdn_identity::NodeId; +use sans_io_runtime::TaskSwitcher; + +use crate::{ + controller_plane::{self, ControllerPlane, ControllerPlaneCfg}, + data_plane::{self, CrossWorker, DataPlane, DataPlaneCfg, NetInput, NetOutput}, + ExtIn, ExtOut, LogicControl, LogicEvent, LogicEventDest, +}; + +#[derive(Debug, Clone)] +pub enum SdnWorkerBusEvent { + Control(LogicControl), + Workers(LogicEvent), + Worker(u16, CrossWorker), +} + +pub enum SdnWorkerInput<'a, SC, SE, TC, TW> { + Ext(ExtIn), + ExtWorker(ExtIn), + Net(NetInput<'a>), + Bus(SdnWorkerBusEvent), + ShutdownRequest, +} + +#[derive(Debug)] +pub enum SdnWorkerOutput<'a, SC, SE, TC, TW> { + Ext(ExtOut), + ExtWorker(ExtOut), + Net(NetOutput<'a>), + Bus(SdnWorkerBusEvent), + ShutdownResponse, + Continue, +} + +pub struct SdnWorkerCfg { + pub node_id: NodeId, + pub tick_ms: u64, + pub controller: Option>, + pub data: DataPlaneCfg, +} + +pub struct SdnWorker { + tick_ms: u64, + controller: Option>, + data: DataPlane, + data_shutdown: bool, + switcher: TaskSwitcher, + last_tick: Option, +} + +impl SdnWorker { + pub fn new(cfg: SdnWorkerCfg) -> Self { + Self { + tick_ms: cfg.tick_ms, + controller: cfg.controller.map(|controller| ControllerPlane::new(cfg.node_id, controller)), + data: DataPlane::new(cfg.node_id, cfg.data), + data_shutdown: false, + switcher: TaskSwitcher::new(2), + last_tick: None, + } + } + + pub fn tasks(&self) -> usize { + let mut tasks = 0; + if self.controller.is_some() { + tasks += 1; + } + if !self.data_shutdown { + tasks += 1; + } + tasks + } + + pub fn on_tick<'a>(&mut self, now_ms: u64) -> Option> { + if let Some(last_tick) = self.last_tick { + if now_ms < last_tick + self.tick_ms { + return None; + } + } + self.last_tick = Some(now_ms); + + self.switcher.queue_flag_all(); + self.data.on_tick(now_ms); + if let Some(controller) = &mut self.controller { + controller.on_tick(now_ms); + if let Some(out) = controller.pop_output(now_ms) { + return Some(self.process_controller_out(now_ms, out)); + } + } + let out = self.data.pop_output(now_ms)?; + Some(self.process_data_out(now_ms, out)) + } + + pub fn on_event<'a>(&mut self, now_ms: u64, input: SdnWorkerInput<'a, SC, SE, TC, TW>) -> Option> { + match input { + SdnWorkerInput::Ext(ext) => { + let controller: &mut ControllerPlane = self.controller.as_mut().expect("Should have controller"); + controller.on_event(now_ms, controller_plane::Input::Ext(ext)); + let out = controller.pop_output(now_ms)?; + Some(self.process_controller_out(now_ms, out)) + } + SdnWorkerInput::ExtWorker(ext) => { + let out = self.data.on_event(now_ms, data_plane::Input::Ext(ext))?; + Some(self.process_data_out(now_ms, out)) + } + SdnWorkerInput::Net(net) => { + let out = self.data.on_event(now_ms, data_plane::Input::Net(net))?; + Some(self.process_data_out(now_ms, out)) + } + SdnWorkerInput::Bus(bus) => match bus { + SdnWorkerBusEvent::Control(control) => { + let controller = self.controller.as_mut().expect("Should have controller"); + controller.on_event(now_ms, controller_plane::Input::Control(control)); + let out = controller.pop_output(now_ms)?; + Some(self.process_controller_out(now_ms, out)) + } + SdnWorkerBusEvent::Workers(event) => { + let out = self.data.on_event(now_ms, data_plane::Input::Event(event))?; + Some(self.process_data_out(now_ms, out)) + } + SdnWorkerBusEvent::Worker(_, cross) => { + let out = self.data.on_event(now_ms, data_plane::Input::Worker(cross))?; + Some(self.process_data_out(now_ms, out)) + } + }, + SdnWorkerInput::ShutdownRequest => { + self.switcher.queue_flag_all(); + if let Some(controller) = &mut self.controller { + controller.on_event(now_ms, controller_plane::Input::ShutdownRequest); + } + if let Some(out) = self.data.on_event(now_ms, data_plane::Input::ShutdownRequest) { + Some(self.process_data_out(now_ms, out)) + } else if let Some(controller) = &mut self.controller { + let out = controller.pop_output(now_ms)?; + Some(self.process_controller_out(now_ms, out)) + } else { + None + } + } + } + } + + pub fn pop_output<'a>(&mut self, now_ms: u64) -> Option> { + while let Some(current) = self.switcher.queue_current() { + match current { + 0 => { + if let Some(controller) = &mut self.controller { + if let Some(out) = self.switcher.queue_process(controller.pop_output(now_ms)) { + return Some(self.process_controller_out(now_ms, out)); + } + } else { + self.switcher.queue_process(None::<()>); + } + } + 1 => { + if let Some(out) = self.switcher.queue_process(self.data.pop_output(now_ms)) { + return Some(self.process_data_out(now_ms, out)); + } + } + _ => panic!("unknown task type"), + } + } + None + } +} + +impl SdnWorker { + fn process_controller_out<'a>(&mut self, now_ms: u64, out: controller_plane::Output) -> SdnWorkerOutput<'a, SC, SE, TC, TW> { + self.switcher.queue_flag_task(0); + match out { + controller_plane::Output::Ext(out) => SdnWorkerOutput::Ext(out), + controller_plane::Output::Event(event) => match event.dest() { + LogicEventDest::Broadcast | LogicEventDest::Worker(_) => SdnWorkerOutput::Bus(SdnWorkerBusEvent::Workers(event)), + LogicEventDest::Any => { + if let Some(out) = self.data.on_event(now_ms, data_plane::Input::Event(event)) { + self.process_data_out(now_ms, out) + } else { + SdnWorkerOutput::Continue + } + } + }, + controller_plane::Output::ShutdownSuccess => { + self.controller = None; + SdnWorkerOutput::Continue + } + } + } + + fn process_data_out<'a>(&mut self, now_ms: u64, out: data_plane::Output<'a, SC, SE, TC>) -> SdnWorkerOutput<'a, SC, SE, TC, TW> { + self.switcher.queue_flag_task(1); + match out { + data_plane::Output::Ext(ext) => SdnWorkerOutput::ExtWorker(ext), + data_plane::Output::Net(out) => SdnWorkerOutput::Net(out), + data_plane::Output::Control(control) => { + if let Some(controller) = &mut self.controller { + log::debug!("Send control to controller {:?}", control); + controller.on_event(now_ms, controller_plane::Input::Control(control)); + if let Some(out) = controller.pop_output(now_ms) { + self.process_controller_out(now_ms, out) + } else { + SdnWorkerOutput::Continue + } + } else { + SdnWorkerOutput::Bus(SdnWorkerBusEvent::Control(control)) + } + } + data_plane::Output::Worker(index, cross) => SdnWorkerOutput::Bus(SdnWorkerBusEvent::Worker(index, cross)), + data_plane::Output::ShutdownResponse => { + self.data_shutdown = true; + SdnWorkerOutput::Continue + } + data_plane::Output::Continue => SdnWorkerOutput::Continue, + } + } +} diff --git a/packages/network/tests/feature_alias.rs b/packages/network/tests/feature_alias.rs index b74da95e..0a4bbf50 100644 --- a/packages/network/tests/feature_alias.rs +++ b/packages/network/tests/feature_alias.rs @@ -36,7 +36,7 @@ impl Service for MockService { struct MockServiceWorker; -impl ServiceWorker for MockServiceWorker { +impl ServiceWorker for MockServiceWorker { fn service_id(&self) -> u8 { 0 } @@ -61,7 +61,7 @@ impl ServiceBuilder for MockServ Box::new(MockService) } - fn create_worker(&self) -> Box> { + fn create_worker(&self) -> Box> { Box::new(MockServiceWorker) } } diff --git a/packages/network/tests/feature_pubsub.rs b/packages/network/tests/feature_pubsub.rs index 8aa91a2c..894920a3 100644 --- a/packages/network/tests/feature_pubsub.rs +++ b/packages/network/tests/feature_pubsub.rs @@ -1,6 +1,6 @@ use atm0s_sdn_network::{ features::{ - pubsub::{ChannelControl, ChannelEvent, ChannelId, Control, Event}, + pubsub::{ChannelControl, ChannelEvent, ChannelId, Control, Event, Feedback}, FeaturesControl, FeaturesEvent, }, ExtIn, ExtOut, @@ -19,7 +19,7 @@ fn event(event: Event) -> ExtOut<()> { } #[test] -fn feature_pubsub_single_node() { +fn feature_pubsub_manual_single_node() { let node_id = 1; let mut sim = NetworkSimulator::<(), (), (), ()>::new(0); sim.add_node(TestNode::new(node_id, 1234, vec![])); @@ -37,7 +37,75 @@ fn feature_pubsub_single_node() { } #[test] -fn feature_pubsub_two_nodes() { +fn feature_pubsub_auto_single_node() { + let node_id = 1; + let mut sim = NetworkSimulator::<(), (), (), ()>::new(0); + sim.add_node(TestNode::new(node_id, 1234, vec![])); + + sim.process(100); + + let channel = ChannelId(1000); + let value = vec![1, 2, 3, 4]; + + sim.control(node_id, control(Control(channel, ChannelControl::PubStart))); + sim.control(node_id, control(Control(channel, ChannelControl::SubAuto))); + sim.process(1); + sim.control(node_id, control(Control(channel, ChannelControl::PubData(value.clone())))); + sim.process(1); + assert_eq!(sim.pop_res(), Some((node_id, event(Event(channel, ChannelEvent::SourceData(node_id, value.clone())))))); + assert_eq!(sim.pop_res(), None); + + log::info!("Simulate feedback source now"); + sim.control(node_id, control(Control(channel, ChannelControl::FeedbackAuto(Feedback::simple(0, 10, 1000, 2000))))); + sim.process(2000); //after that tick feedback will timeout + assert_eq!(sim.pop_res(), Some((node_id, event(Event(channel, ChannelEvent::FeedbackData(Feedback::simple(0, 10, 1000, 2000))))))); + assert_eq!(sim.pop_res(), None); + + sim.control(node_id, control(Control(channel, ChannelControl::UnsubAuto))); + sim.process(1); + sim.control(node_id, control(Control(channel, ChannelControl::PubData(value.clone())))); + sim.process(1); + assert_eq!(sim.pop_res(), None); +} + +#[test] +fn feature_pubsub_auto_single_node_worker() { + let node_id = 1; + let mut sim = NetworkSimulator::<(), (), (), ()>::new(0); + sim.enable_log(log::LevelFilter::Debug); + sim.add_node(TestNode::new(node_id, 1234, vec![])); + + sim.process(100); + + let channel = ChannelId(1000); + let value = vec![1, 2, 3, 4]; + + sim.control_worker(node_id, control(Control(channel, ChannelControl::PubStart))); + sim.control_worker(node_id, control(Control(channel, ChannelControl::SubAuto))); + sim.process(1); + sim.control_worker(node_id, control(Control(channel, ChannelControl::PubData(value.clone())))); + sim.process(1); + assert_eq!(sim.pop_res_worker(), Some((node_id, event(Event(channel, ChannelEvent::SourceData(node_id, value.clone())))))); + assert_eq!(sim.pop_res_worker(), None); + + log::info!("Simulate feedback source now"); + sim.control_worker(node_id, control(Control(channel, ChannelControl::FeedbackAuto(Feedback::simple(0, 10, 1000, 2000))))); + sim.process(2000); //after that tick feedback will timeout + assert_eq!( + sim.pop_res_worker(), + Some((node_id, event(Event(channel, ChannelEvent::FeedbackData(Feedback::simple(0, 10, 1000, 2000)))))) + ); + assert_eq!(sim.pop_res_worker(), None); + + sim.control_worker(node_id, control(Control(channel, ChannelControl::UnsubAuto))); + sim.process(1); + sim.control_worker(node_id, control(Control(channel, ChannelControl::PubData(value.clone())))); + sim.process(1); + assert_eq!(sim.pop_res_worker(), None); +} + +#[test] +fn feature_pubsub_manual_two_nodes() { let node1 = 1; let node2 = 2; let mut sim = NetworkSimulator::<(), (), (), ()>::new(0); @@ -65,7 +133,48 @@ fn feature_pubsub_two_nodes() { } #[test] -fn feature_pubsub_three_nodes() { +fn feature_pubsub_auto_two_nodes() { + let node1 = 1; + let node2 = 2; + let mut sim = NetworkSimulator::<(), (), (), ()>::new(0); + + let _addr1 = sim.add_node(TestNode::new(node1, 1234, vec![])); + let addr2 = sim.add_node(TestNode::new(node2, 1235, vec![])); + + sim.control(node1, ExtIn::ConnectTo(addr2)); + + // For sync + for _i in 0..4 { + sim.process(500); + } + + let channel = ChannelId(1000); + let value = vec![1, 2, 3, 4]; + + sim.control(node2, control(Control(channel, ChannelControl::PubStart))); + sim.control(node1, control(Control(channel, ChannelControl::SubAuto))); + sim.process(1); + + sim.control(node2, control(Control(channel, ChannelControl::PubData(value.clone())))); + sim.process(1); + assert_eq!(sim.pop_res(), Some((node1, event(Event(channel, ChannelEvent::SourceData(node2, value.clone())))))); + assert_eq!(sim.pop_res(), None); + + log::info!("Simulate feedback source now"); + sim.control(node1, control(Control(channel, ChannelControl::FeedbackAuto(Feedback::simple(0, 10, 1000, 2000))))); + sim.process(2000); //after that tick feedback will timeout + assert_eq!(sim.pop_res(), Some((node2, event(Event(channel, ChannelEvent::FeedbackData(Feedback::simple(0, 10, 1000, 2000))))))); + assert_eq!(sim.pop_res(), None); + + sim.control(node1, control(Control(channel, ChannelControl::UnsubAuto))); + sim.process(1); + sim.control(node2, control(Control(channel, ChannelControl::PubData(value)))); + sim.process(1); + assert_eq!(sim.pop_res(), None); +} + +#[test] +fn feature_pubsub_manual_three_nodes() { let node1 = 1; let node2 = 2; let node3 = 3; @@ -94,3 +203,93 @@ fn feature_pubsub_three_nodes() { assert_eq!(sim.pop_res(), Some((node1, event(Event(channel, ChannelEvent::SourceData(node3, value)))))); assert_eq!(sim.pop_res(), None); } + +#[test] +fn feature_pubsub_auto_three_nodes() { + let node1 = 1; + let node2 = 2; + let node3 = 3; + let mut sim = NetworkSimulator::<(), (), (), ()>::new(0); + + let _addr1 = sim.add_node(TestNode::new(node1, 1234, vec![])); + let addr2 = sim.add_node(TestNode::new(node2, 1235, vec![])); + let addr3 = sim.add_node(TestNode::new(node3, 1236, vec![])); + + sim.control(node1, ExtIn::ConnectTo(addr2)); + sim.control(node2, ExtIn::ConnectTo(addr3)); + + // For sync + for _i in 0..4 { + sim.process(500); + } + + let channel = ChannelId(1000); + let value = vec![1, 2, 3, 4]; + + sim.control(node1, control(Control(channel, ChannelControl::SubAuto))); + sim.process(1); + sim.control(node3, control(Control(channel, ChannelControl::PubStart))); + sim.process(1); + + sim.control(node3, control(Control(channel, ChannelControl::PubData(value.clone())))); + sim.process(1); + assert_eq!(sim.pop_res(), Some((node1, event(Event(channel, ChannelEvent::SourceData(node3, value.clone())))))); + assert_eq!(sim.pop_res(), None); + + log::info!("Simulate feedback source now"); + sim.control(node1, control(Control(channel, ChannelControl::FeedbackAuto(Feedback::simple(0, 10, 1000, 2000))))); + sim.process(2000); //after that tick feedback will timeout + assert_eq!(sim.pop_res(), Some((node3, event(Event(channel, ChannelEvent::FeedbackData(Feedback::simple(0, 10, 1000, 2000))))))); + assert_eq!(sim.pop_res(), None); + + sim.control(node1, control(Control(channel, ChannelControl::UnsubAuto))); + sim.process(1); + sim.control(node3, control(Control(channel, ChannelControl::PubData(value)))); + sim.process(1); + assert_eq!(sim.pop_res(), None); +} + +#[test] +fn feature_pubsub_auto_three_nodes_sub_after_start() { + let node1 = 1; + let node2 = 2; + let node3 = 3; + let mut sim = NetworkSimulator::<(), (), (), ()>::new(0); + + let _addr1 = sim.add_node(TestNode::new(node1, 1234, vec![])); + let addr2 = sim.add_node(TestNode::new(node2, 1235, vec![])); + let addr3 = sim.add_node(TestNode::new(node3, 1236, vec![])); + + sim.control(node1, ExtIn::ConnectTo(addr2)); + sim.control(node2, ExtIn::ConnectTo(addr3)); + + // For sync + for _i in 0..4 { + sim.process(500); + } + + let channel = ChannelId(1000); + let value = vec![1, 2, 3, 4]; + + sim.control(node3, control(Control(channel, ChannelControl::PubStart))); + sim.process(1); + sim.control(node1, control(Control(channel, ChannelControl::SubAuto))); + sim.process(1); + + sim.control(node3, control(Control(channel, ChannelControl::PubData(value.clone())))); + sim.process(1); + assert_eq!(sim.pop_res(), Some((node1, event(Event(channel, ChannelEvent::SourceData(node3, value.clone())))))); + assert_eq!(sim.pop_res(), None); + + log::info!("Simulate feedback source now"); + sim.control(node1, control(Control(channel, ChannelControl::FeedbackAuto(Feedback::simple(0, 10, 1000, 2000))))); + sim.process(2000); //after that tick feedback will timeout + assert_eq!(sim.pop_res(), Some((node3, event(Event(channel, ChannelEvent::FeedbackData(Feedback::simple(0, 10, 1000, 2000))))))); + assert_eq!(sim.pop_res(), None); + + sim.control(node1, control(Control(channel, ChannelControl::UnsubAuto))); + sim.process(1); + sim.control(node3, control(Control(channel, ChannelControl::PubData(value)))); + sim.process(1); + assert_eq!(sim.pop_res(), None); +} diff --git a/packages/network/tests/feature_router_sync.rs b/packages/network/tests/feature_router_sync.rs index 02698bff..eb347834 100644 --- a/packages/network/tests/feature_router_sync.rs +++ b/packages/network/tests/feature_router_sync.rs @@ -33,7 +33,7 @@ impl Service for MockService { struct MockServiceWorker; -impl ServiceWorker for MockServiceWorker { +impl ServiceWorker for MockServiceWorker { fn service_id(&self) -> u8 { 0 } @@ -58,7 +58,7 @@ impl ServiceBuilder for MockServ Box::new(MockService) } - fn create_worker(&self) -> Box> { + fn create_worker(&self) -> Box> { Box::new(MockServiceWorker) } } diff --git a/packages/network/tests/service_visualization.rs b/packages/network/tests/service_visualization.rs index 317441b6..29791bd3 100644 --- a/packages/network/tests/service_visualization.rs +++ b/packages/network/tests/service_visualization.rs @@ -11,18 +11,21 @@ use crate::simulator::{node_to_addr, NetworkSimulator, TestNode}; mod simulator; fn node_changed(node: NodeId, remotes: &[(NodeId, ConnId)]) -> ExtOut { - ExtOut::ServicesEvent(Event::NodeChanged( - node, - remotes - .iter() - .map(|(n, c)| ConnectionInfo { - conn: *c, - dest: *n, - remote: node_to_addr(*n), - rtt_ms: 0, - }) - .collect(), - )) + ExtOut::ServicesEvent( + visualization::SERVICE_ID.into(), + Event::NodeChanged( + node, + remotes + .iter() + .map(|(n, c)| ConnectionInfo { + conn: *c, + dest: *n, + remote: node_to_addr(*n), + rtt_ms: 0, + }) + .collect(), + ), + ) } #[test] @@ -41,8 +44,8 @@ fn service_visualization_simple() { sim.process(1000); } - assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(Event::GotAll(vec![]))))); - assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(Event::NodeChanged(node1, vec![]))))); + assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(visualization::SERVICE_ID.into(), Event::GotAll(vec![]))))); + assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(visualization::SERVICE_ID.into(), Event::NodeChanged(node1, vec![]))))); sim.control(node1, ExtIn::ConnectTo(addr2)); @@ -79,10 +82,10 @@ fn service_visualization_multi_collectors() { sim.process(1000); } - assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(Event::GotAll(vec![]))))); - assert_eq!(sim.pop_res(), Some((node2, ExtOut::ServicesEvent(Event::GotAll(vec![]))))); - assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(Event::NodeChanged(node1, vec![]))))); - assert_eq!(sim.pop_res(), Some((node2, ExtOut::ServicesEvent(Event::NodeChanged(node2, vec![]))))); + assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(visualization::SERVICE_ID.into(), Event::GotAll(vec![]))))); + assert_eq!(sim.pop_res(), Some((node2, ExtOut::ServicesEvent(visualization::SERVICE_ID.into(), Event::GotAll(vec![]))))); + assert_eq!(sim.pop_res(), Some((node1, ExtOut::ServicesEvent(visualization::SERVICE_ID.into(), Event::NodeChanged(node1, vec![]))))); + assert_eq!(sim.pop_res(), Some((node2, ExtOut::ServicesEvent(visualization::SERVICE_ID.into(), Event::NodeChanged(node2, vec![]))))); sim.control(node1, ExtIn::ConnectTo(addr2)); sim.control(node2, ExtIn::ConnectTo(addr3)); @@ -107,7 +110,7 @@ fn service_visualization_multi_collectors() { let get_key = |a: &ExtOut| -> u32 { match a { - ExtOut::ServicesEvent(Event::NodeChanged(node, _)) => *node, + ExtOut::ServicesEvent(_service, Event::NodeChanged(node, _)) => *node, _ => panic!("Unexpected event: {:?}", a), } }; diff --git a/packages/network/tests/simulator.rs b/packages/network/tests/simulator.rs index aa379f41..9134c458 100644 --- a/packages/network/tests/simulator.rs +++ b/packages/network/tests/simulator.rs @@ -11,18 +11,20 @@ use std::{collections::VecDeque, net::IpAddr}; use atm0s_sdn_identity::{NodeAddr, NodeAddrBuilder, NodeId, Protocol}; use atm0s_sdn_network::base::ServiceBuilder; +use atm0s_sdn_network::controller_plane::ControllerPlaneCfg; +use atm0s_sdn_network::data_plane::DataPlaneCfg; use atm0s_sdn_network::features::{FeaturesControl, FeaturesEvent}; use atm0s_sdn_network::secure::{HandshakeBuilderXDA, StaticKeyAuthorization}; +use atm0s_sdn_network::worker::{SdnWorker, SdnWorkerCfg, SdnWorkerInput, SdnWorkerOutput}; use atm0s_sdn_network::{ - base::{GenericBuffer, GenericBufferMut}, - controller_plane::{self, ControllerPlane}, - data_plane::{self, DataPlane}, - ExtIn, ExtOut, + base::{Buffer, BufferMut}, + data_plane, ExtIn, ExtOut, }; use atm0s_sdn_router::shadow::ShadowRouterHistory; use log::{LevelFilter, Metadata, Record}; use parking_lot::Mutex; use rand::rngs::mock::StepRng; +use sans_io_runtime::TaskSwitcher; static CONTEXT_LOGGER: ContextLogger = ContextLogger { node: Mutex::new(None) }; @@ -74,16 +76,20 @@ impl Drop for AutoContext { #[derive(Debug)] pub enum TestNodeIn<'a, SC> { Ext(ExtIn), - Udp(SocketAddr, GenericBufferMut<'a>), - #[allow(unused)] - Tun(GenericBufferMut<'a>), + ExtWorker(ExtIn), + Udp(SocketAddr, BufferMut<'a>), + #[cfg(feature = "vpn")] + Tun(BufferMut<'a>), } #[derive(Debug)] pub enum TestNodeOut<'a, SE> { Ext(ExtOut), - Udp(Vec, GenericBuffer<'a>), - Tun(GenericBuffer<'a>), + ExtWorker(ExtOut), + Udp(Vec, Buffer<'a>), + #[cfg(feature = "vpn")] + Tun(Buffer<'a>), + Continue, } pub fn build_addr(node_id: NodeId) -> NodeAddr { @@ -101,10 +107,10 @@ struct SingleThreadDataWorkerHistory { impl ShadowRouterHistory for SingleThreadDataWorkerHistory { fn already_received_broadcast(&self, from: Option, service: u8, seq: u16) -> bool { - log::debug!("Check already_received_broadcast from {:?} service {} seq {}", from, service, seq); let mut map = self.map.lock(); let mut queue = self.queue.lock(); if map.contains_key(&(from, service, seq)) { + log::debug!("already_received_broadcast from {:?} service {} seq {}", from, service, seq); return true; } map.insert((from, service, seq), true); @@ -120,19 +126,34 @@ impl ShadowRouterHistory for SingleThreadDataWorkerHistory { pub struct TestNode { node_id: NodeId, - controller: ControllerPlane, - worker: DataPlane, + worker: SdnWorker, } -impl TestNode { +impl TestNode { pub fn new(node_id: NodeId, session: u64, services: Vec>>) -> Self { let _log = AutoContext::new(node_id); - let auth = Arc::new(StaticKeyAuthorization::new("demo-key")); - let handshake = Arc::new(HandshakeBuilderXDA); - let rd = Box::new(StepRng::new(1000, 5)); - let controller = ControllerPlane::new(node_id, session, services.clone(), auth, handshake, rd); - let worker = DataPlane::new(node_id, services, Arc::new(SingleThreadDataWorkerHistory::default())); - Self { node_id, controller, worker } + let authorization: Arc = Arc::new(StaticKeyAuthorization::new("demo-key")); + let handshake_builder = Arc::new(HandshakeBuilderXDA); + let random = Box::new(StepRng::new(1000, 5)); + Self { + node_id, + worker: SdnWorker::new(SdnWorkerCfg { + node_id, + tick_ms: 1, + controller: Some(ControllerPlaneCfg { + session, + services: services.clone(), + authorization, + handshake_builder, + random, + }), + data: DataPlaneCfg { + worker_id: 0, + services, + history: Arc::new(SingleThreadDataWorkerHistory::default()), + }, + }), + } } pub fn node_id(&self) -> NodeId { @@ -143,80 +164,58 @@ impl TestNode { build_addr(self.node_id) } - pub fn tick(&mut self, now: u64) { + pub fn tick<'a>(&mut self, now: u64) -> Option> { let _log = AutoContext::new(self.node_id); - self.controller.on_tick(now); - self.worker.on_tick(now); + let out = self.worker.on_tick(now)?; + Some(self.process_worker_output(now, out)) } pub fn on_input<'a>(&mut self, now: u64, input: TestNodeIn<'a, SC>) -> Option> { let _log = AutoContext::new(self.node_id); match input { TestNodeIn::Ext(ext_in) => { - self.controller.on_event(now, controller_plane::Input::Ext(ext_in)); - let out = self.controller.pop_output(now)?; - self.process_controller_output(now, out) + let out = self.worker.on_event(now, SdnWorkerInput::Ext(ext_in))?; + Some(self.process_worker_output(now, out)) + } + TestNodeIn::ExtWorker(ext_in) => { + let out = self.worker.on_event(now, SdnWorkerInput::ExtWorker(ext_in))?; + Some(self.process_worker_output(now, out)) } TestNodeIn::Udp(addr, buf) => { - let out = self.worker.on_event(now, data_plane::Input::Net(data_plane::NetInput::UdpPacket(addr, buf)))?; - self.process_worker_output(now, out) + let out = self.worker.on_event(now, SdnWorkerInput::Net(data_plane::NetInput::UdpPacket(addr, buf)))?; + Some(self.process_worker_output(now, out)) } + #[cfg(feature = "vpn")] TestNodeIn::Tun(buf) => { - let out = self.worker.on_event(now, data_plane::Input::Net(data_plane::NetInput::TunPacket(buf)))?; - self.process_worker_output(now, out) + let out = self.worker.on_event(now, SdnWorkerInput::Net(data_plane::NetInput::TunPacket(buf)))?; + Some(self.process_worker_output(now, out)) } } } pub fn pop_output<'a>(&mut self, now: u64) -> Option> { let _log = AutoContext::new(self.node_id); - let mut keep_running = true; - while keep_running { - keep_running = false; - - if let Some(output) = self.controller.pop_output(now) { - keep_running = true; - if let Some(out) = self.process_controller_output(now, output) { - return Some(out); - } - } - - if let Some(output) = self.worker.pop_output(now) { - keep_running = true; - if let Some(out) = self.process_worker_output(now, output) { - return Some(out); - } - } - } - None + let output = self.worker.pop_output(now)?; + Some(self.process_worker_output(now, output)) } - fn process_controller_output<'a>(&mut self, now: u64, output: controller_plane::Output) -> Option> { + fn process_worker_output<'a>(&mut self, now: u64, output: SdnWorkerOutput<'a, SC, SE, TC, TW>) -> TestNodeOut<'a, SE> { match output { - controller_plane::Output::Event(e) => { - let output = self.worker.on_event(now, data_plane::Input::Event(e))?; - self.process_worker_output(now, output) - } - controller_plane::Output::Ext(out) => Some(TestNodeOut::Ext(out)), - controller_plane::Output::ShutdownSuccess => None, - } - } - - fn process_worker_output<'a>(&mut self, now: u64, output: data_plane::Output<'a, SE, TC>) -> Option> { - match output { - data_plane::Output::Ext(out) => Some(TestNodeOut::Ext(out)), - data_plane::Output::Control(control) => { - self.controller.on_event(now, controller_plane::Input::Control(control)); - let output = self.controller.pop_output(now)?; - self.process_controller_output(now, output) + SdnWorkerOutput::Ext(ext) => TestNodeOut::Ext(ext), + SdnWorkerOutput::ExtWorker(ext) => TestNodeOut::ExtWorker(ext), + SdnWorkerOutput::Net(data_plane::NetOutput::UdpPacket(dest, data)) => TestNodeOut::Udp(vec![dest], data), + SdnWorkerOutput::Net(data_plane::NetOutput::UdpPackets(dests, data)) => TestNodeOut::Udp(dests, data), + #[cfg(feature = "vpn")] + SdnWorkerOutput::Net(data_plane::NetOutput::TunPacket(data)) => TestNodeOut::Tun(data), + SdnWorkerOutput::Bus(bus) => { + if let Some(out) = self.worker.on_event(now, SdnWorkerInput::Bus(bus)) { + self.process_worker_output(now, out) + } else { + TestNodeOut::Continue + } } - data_plane::Output::Net(out) => match out { - data_plane::NetOutput::UdpPacket(dest, buf) => Some(TestNodeOut::Udp(vec![dest], buf)), - data_plane::NetOutput::UdpPackets(dest, buf) => Some(TestNodeOut::Udp(dest, buf)), - data_plane::NetOutput::TunPacket(buf) => Some(TestNodeOut::Tun(buf)), - }, - data_plane::Output::ShutdownResponse => None, - data_plane::Output::Continue => None, + SdnWorkerOutput::ShutdownResponse => todo!(), + SdnWorkerOutput::Continue => TestNodeOut::Continue, } } } @@ -232,19 +231,25 @@ pub fn node_to_addr(node: NodeId) -> SocketAddr { pub struct NetworkSimulator { clock_ms: u64, input: VecDeque<(NodeId, ExtIn)>, + input_worker: VecDeque<(NodeId, ExtIn)>, output: VecDeque<(NodeId, ExtOut)>, + output_worker: VecDeque<(NodeId, ExtOut)>, nodes: Vec>, nodes_index: HashMap, + switcher: TaskSwitcher, } -impl NetworkSimulator { +impl NetworkSimulator { pub fn new(started_ms: u64) -> Self { Self { clock_ms: started_ms, input: VecDeque::new(), output: VecDeque::new(), + input_worker: VecDeque::new(), + output_worker: VecDeque::new(), nodes: Vec::new(), nodes_index: HashMap::new(), + switcher: TaskSwitcher::new(0), } } @@ -262,83 +267,84 @@ impl NetworkSimulator { self.output.pop_front() } + pub fn control_worker(&mut self, node: NodeId, control: ExtIn) { + self.input_worker.push_back((node, control)); + } + + pub fn pop_res_worker(&mut self) -> Option<(NodeId, ExtOut)> { + self.output_worker.pop_front() + } + pub fn add_node(&mut self, node: TestNode) -> NodeAddr { let index = self.nodes.len(); self.nodes_index.insert(node.node_id(), index); let addr = node.addr(); self.nodes.push(node); + self.switcher.set_tasks(self.nodes.len()); addr } pub fn process(&mut self, delta: u64) { self.clock_ms += delta; log::debug!("Tick {} ms", self.clock_ms); - for node in self.nodes.iter_mut() { - node.tick(self.clock_ms); + for i in 0..self.nodes.len() { + let node_id = self.nodes[i].node_id(); + if let Some(out) = self.nodes[i].tick(self.clock_ms) { + self.process_out(self.clock_ms, node_id, out); + } } - self.pop_outputs(); - - if !self.input.is_empty() { - while let Some((node, input)) = self.input.pop_front() { - self.process_input(node, TestNodeIn::Ext(input)); + while let Some((node, input)) = self.input.pop_front() { + let node_index = *self.nodes_index.get(&node).expect("Node not found"); + if let Some(out) = self.nodes[node_index].on_input(self.clock_ms, TestNodeIn::Ext(input)) { + self.process_out(self.clock_ms, node, out); } - - self.pop_outputs(); } - } - fn process_input<'a>(&mut self, node: NodeId, input: TestNodeIn<'a, SC>) -> Option<()> { - let index = self.nodes_index.get(&node).expect("Node not found"); - let output = self.nodes[*index].on_input(self.clock_ms, input)?; - match output { - TestNodeOut::Ext(out) => { - self.output.push_back((node, out)); - Some(()) + while let Some((node, input)) = self.input_worker.pop_front() { + let node_index = *self.nodes_index.get(&node).expect("Node not found"); + if let Some(out) = self.nodes[node_index].on_input(self.clock_ms, TestNodeIn::ExtWorker(input)) { + self.process_out(self.clock_ms, node, out); } - TestNodeOut::Udp(dests, data) => { - let source_addr = node_to_addr(node); - for dest in dests { - let dest_node = addr_to_node(dest); - self.process_input(dest_node, TestNodeIn::Udp(source_addr, data.clone_mut())); - } - Some(()) - } - TestNodeOut::Tun(_) => todo!(), } + + self.switcher.queue_flag_all(); + self.pop_outputs(self.clock_ms); } - fn pop_outputs(&mut self) { - let mut keep_running = true; - while keep_running { - keep_running = false; - for index in 0..self.nodes.len() { - let node = self.nodes[index].node_id(); - if self.pop_output(node).is_some() { - keep_running = true; - } + fn pop_outputs(&mut self, now: u64) { + while let Some(index) = self.switcher.queue_current() { + let node = self.nodes[index].node_id(); + if let Some(out) = self.switcher.queue_process(self.nodes[index].pop_output(now)) { + self.process_out(now, node, out); } } } - fn pop_output<'a>(&mut self, node: NodeId) -> Option<()> { - let index = self.nodes_index.get(&node).expect("Node not found"); - let output = self.nodes[*index].pop_output(self.clock_ms)?; - match output { + fn process_out(&mut self, now: u64, node: NodeId, out: TestNodeOut) { + let node_index = *self.nodes_index.get(&node).expect("Node not found"); + self.switcher.queue_flag_task(node_index); + match out { TestNodeOut::Ext(out) => { self.output.push_back((node, out)); - Some(()) + } + TestNodeOut::ExtWorker(out) => { + self.output_worker.push_back((node, out)); } TestNodeOut::Udp(dests, data) => { let source_addr = node_to_addr(node); for dest in dests { log::debug!("Send UDP packet from {} to {}, buf len {}", source_addr, dest, data.len()); let dest_node = addr_to_node(dest); - self.process_input(dest_node, TestNodeIn::Udp(source_addr, data.clone_mut())); + let dest_index = *self.nodes_index.get(&dest_node).expect("Node not found"); + if let Some(out) = self.nodes[dest_index].on_input(now, TestNodeIn::Udp(source_addr, data.clone_mut())) { + self.process_out(now, dest_node, out); + } } - Some(()) } + #[cfg(feature = "vpn")] TestNodeOut::Tun(_) => todo!(), + TestNodeOut::Continue => {} } } } diff --git a/packages/runner/Cargo.toml b/packages/runner/Cargo.toml index 94efc81d..9eb077f9 100644 --- a/packages/runner/Cargo.toml +++ b/packages/runner/Cargo.toml @@ -10,8 +10,7 @@ license = "MIT" [dependencies] thiserror = { workspace = true } -sans-io-runtime = { workspace = true, features = ["poll-backend", "polling-backend", "mio-backend", "udp"] } -# sans-io-runtime = { path = "../../../sans-io-runtime", features = ["poll-backend", "polling-backend", "mio-backend", "udp"]} +sans-io-runtime = { workspace = true, features = ["poll-backend", "polling-backend", "udp"] } atm0s-sdn-identity = { path = "../core/identity", version = "0.2.0" } atm0s-sdn-router = { path = "../core/router", version = "0.1.4" } atm0s-sdn-network = { path = "../network", version = "0.3.1" } diff --git a/packages/runner/examples/simple_kv.rs b/packages/runner/examples/simple_kv.rs index a527f101..3c182a65 100644 --- a/packages/runner/examples/simple_kv.rs +++ b/packages/runner/examples/simple_kv.rs @@ -8,10 +8,7 @@ use atm0s_sdn_network::{ services::visualization, }; use clap::{Parser, ValueEnum}; -use sans_io_runtime::{ - backend::{MioBackend, PollBackend, PollingBackend}, - Owner, -}; +use sans_io_runtime::backend::{PollBackend, PollingBackend}; use std::{ sync::{ atomic::{AtomicBool, Ordering}, @@ -20,13 +17,12 @@ use std::{ time::Duration, }; -use atm0s_sdn::{builder::SdnBuilder, tasks::SdnExtIn}; +use atm0s_sdn::{SdnBuilder, SdnExtIn, SdnOwner}; #[derive(Debug, Clone, ValueEnum)] enum BackendType { Poll, Polling, - Mio, } /// Simple program to running a node @@ -89,13 +85,12 @@ fn main() { } let mut controller = match args.backend { - BackendType::Mio => builder.build::>(args.workers), - BackendType::Poll => builder.build::>(args.workers), - BackendType::Polling => builder.build::>(args.workers), + BackendType::Poll => builder.build::>(args.workers), + BackendType::Polling => builder.build::>(args.workers), }; if args.kv_subscribe { - controller.send_to(Owner::worker(0), SdnExtIn::FeaturesControl(FeaturesControl::DhtKv(Control::MapCmd(Map(args.kv_map), MapControl::Sub)))); + controller.send_to(0, SdnExtIn::FeaturesControl(FeaturesControl::DhtKv(Control::MapCmd(Map(args.kv_map), MapControl::Sub)))); } let mut i: u32 = 0; @@ -114,10 +109,7 @@ fn main() { if i % 1000 == 0 && args.kv_set { let data = i.to_be_bytes().to_vec(); log::info!("Set key: {:?}", data); - controller.send_to( - Owner::worker(0), - SdnExtIn::FeaturesControl(FeaturesControl::DhtKv(Control::MapCmd(Map(args.kv_map), MapControl::Set(Key(200), data)))), - ); + controller.send_to(0, SdnExtIn::FeaturesControl(FeaturesControl::DhtKv(Control::MapCmd(Map(args.kv_map), MapControl::Set(Key(200), data))))); } while let Some(out) = controller.pop_event() { log::info!("Got event: {:?}", out); diff --git a/packages/runner/examples/simple_node.rs b/packages/runner/examples/simple_node.rs index c814211d..8a841931 100644 --- a/packages/runner/examples/simple_node.rs +++ b/packages/runner/examples/simple_node.rs @@ -1,7 +1,7 @@ use atm0s_sdn_identity::{NodeAddr, NodeId}; use atm0s_sdn_network::{secure::StaticKeyAuthorization, services::visualization}; use clap::{Parser, ValueEnum}; -use sans_io_runtime::backend::{MioBackend, PollBackend, PollingBackend}; +use sans_io_runtime::backend::{PollBackend, PollingBackend}; use std::{ net::SocketAddr, sync::{ @@ -11,13 +11,12 @@ use std::{ time::Duration, }; -use atm0s_sdn::builder::SdnBuilder; +use atm0s_sdn::{SdnBuilder, SdnOwner}; #[derive(Debug, Clone, ValueEnum)] enum BackendType { Poll, Polling, - Mio, } /// Simple program to running a node @@ -91,9 +90,8 @@ fn main() { } let mut controller = match args.backend { - BackendType::Mio => builder.build::>(args.workers), - BackendType::Poll => builder.build::>(args.workers), - BackendType::Polling => builder.build::>(args.workers), + BackendType::Poll => builder.build::>(args.workers), + BackendType::Polling => builder.build::>(args.workers), }; while controller.process().is_some() { diff --git a/packages/runner/run-example-debug.sh b/packages/runner/run-example-debug.sh new file mode 100644 index 00000000..54bcb4d5 --- /dev/null +++ b/packages/runner/run-example-debug.sh @@ -0,0 +1,3 @@ +export RUST_LOG=info +cargo build --features vpn --example simple_node +sudo --preserve-env=RUST_LOG ../../target/debug/examples/simple_node $@ diff --git a/packages/runner/run-example-release.sh b/packages/runner/run-example-release.sh new file mode 100644 index 00000000..93c2edfb --- /dev/null +++ b/packages/runner/run-example-release.sh @@ -0,0 +1,3 @@ +export RUST_LOG=info +cargo build --release --features vpn --example simple_node +sudo --preserve-env=RUST_LOG ../../target/release/examples/simple_node $@ diff --git a/packages/runner/src/builder.rs b/packages/runner/src/builder.rs index df4b1fe9..69f241b5 100644 --- a/packages/runner/src/builder.rs +++ b/packages/runner/src/builder.rs @@ -13,9 +13,12 @@ use atm0s_sdn_network::{ services::{manual_discovery, visualization}, }; use rand::{thread_rng, RngCore}; -use sans_io_runtime::{backend::Backend, Owner}; +use sans_io_runtime::backend::Backend; -use crate::tasks::{ControllerCfg, DataWorkerHistory, SdnController, SdnExtIn, SdnInnerCfg, SdnWorkerInner}; +use crate::{ + history::DataWorkerHistory, + worker_inner::{ControllerCfg, SdnController, SdnExtIn, SdnInnerCfg, SdnOwner, SdnWorkerInner}, +}; pub struct SdnBuilder { auth: Option>, @@ -24,6 +27,7 @@ pub struct SdnBuilder { node_id: NodeId, session: u64, udp_port: u16, + tick_ms: u64, visualization_collector: bool, seeds: Vec, services: Vec>>, @@ -77,6 +81,7 @@ where handshake: None, node_addr, node_id, + tick_ms: 1000, session: thread_rng().next_u64(), udp_port, visualization_collector: false, @@ -142,7 +147,7 @@ where self.vpn_netmask = Some(netmask); } - pub fn build(mut self, workers: usize) -> SdnController { + pub fn build>(mut self, workers: usize) -> SdnController { assert!(workers > 0); #[cfg(feature = "vpn")] let (tun_device, mut queue_fds) = { @@ -165,10 +170,11 @@ where let history = Arc::new(DataWorkerHistory::default()); let mut controller = SdnController::default(); - controller.add_worker::<_, SdnWorkerInner, B>( + controller.add_worker::, B>( Duration::from_millis(1000), SdnInnerCfg { node_id: self.node_id, + tick_ms: self.tick_ms, udp_port: self.udp_port, services: self.services.clone(), history: history.clone(), @@ -187,10 +193,11 @@ where ); for _ in 1..workers { - controller.add_worker::<_, SdnWorkerInner, B>( + controller.add_worker::, B>( Duration::from_millis(1000), SdnInnerCfg { node_id: self.node_id, + tick_ms: self.tick_ms, udp_port: self.udp_port, services: self.services.clone(), history: history.clone(), @@ -205,7 +212,7 @@ where std::thread::sleep(std::time::Duration::from_millis(100)); for seed in self.seeds { - controller.send_to(Owner::worker(0), SdnExtIn::ConnectTo(seed)); + controller.send_to(0, SdnExtIn::ConnectTo(seed)); } controller diff --git a/packages/runner/src/tasks/data_plane/history.rs b/packages/runner/src/history.rs similarity index 97% rename from packages/runner/src/tasks/data_plane/history.rs rename to packages/runner/src/history.rs index 9b3b071b..3c3cf36e 100644 --- a/packages/runner/src/tasks/data_plane/history.rs +++ b/packages/runner/src/history.rs @@ -54,7 +54,7 @@ impl ShadowRouterHistory for DataWorkerHistory { mod tests { use atm0s_sdn_router::shadow::ShadowRouterHistory; - use crate::tasks::data_plane::history::HISTORY_TIMEOUT_MS; + use crate::history::HISTORY_TIMEOUT_MS; use super::DataWorkerHistory; diff --git a/packages/runner/src/lib.rs b/packages/runner/src/lib.rs index 40360b06..d958a078 100644 --- a/packages/runner/src/lib.rs +++ b/packages/runner/src/lib.rs @@ -1,12 +1,46 @@ pub use atm0s_sdn_identity::{ConnDirection, ConnId, NodeAddr, NodeAddrBuilder, NodeId, NodeIdType, Protocol}; -pub use atm0s_sdn_network::base; +pub use atm0s_sdn_network::controller_plane::ControllerPlaneCfg; pub use atm0s_sdn_network::convert_enum; -pub use atm0s_sdn_network::features; -pub use atm0s_sdn_network::secure; -pub use atm0s_sdn_network::services; -pub use atm0s_sdn_router::ServiceBroadcastLevel; +pub use atm0s_sdn_network::data_plane::DataPlaneCfg; +use atm0s_sdn_network::features::FeaturesControl; +pub use atm0s_sdn_network::{ + base, features, secure, services, + worker::{SdnWorker, SdnWorkerBusEvent, SdnWorkerCfg, SdnWorkerInput, SdnWorkerOutput}, +}; +pub use atm0s_sdn_network::{ + base::ServiceId, + data_plane::{NetInput, NetOutput}, +}; +pub use atm0s_sdn_router::{shadow::ShadowRouterHistory, ServiceBroadcastLevel}; pub use sans_io_runtime; -pub mod builder; -pub mod tasks; -pub mod time; +mod builder; +mod history; +mod time; +mod worker_inner; + +pub use builder::SdnBuilder; +pub use history::DataWorkerHistory; +pub use time::{TimePivot, TimeTicker}; +pub use worker_inner::{SdnChannel, SdnController, SdnEvent, SdnExtIn, SdnExtOut, SdnOwner}; + +pub trait SdnControllerUtils { + fn connect_to(&mut self, addr: NodeAddr); + fn feature_control(&mut self, cmd: FeaturesControl); + fn service_control(&mut self, service: ServiceId, cmd: SC); +} + +impl SdnControllerUtils + for SdnController +{ + fn connect_to(&mut self, addr: NodeAddr) { + self.send_to(0, SdnExtIn::ConnectTo(addr)); + } + fn feature_control(&mut self, cmd: FeaturesControl) { + self.send_to(0, SdnExtIn::FeaturesControl(cmd)); + } + + fn service_control(&mut self, service: ServiceId, cmd: SC) { + self.send_to(0, SdnExtIn::ServicesControl(service, cmd)); + } +} diff --git a/packages/runner/src/tasks/controller_plane.rs b/packages/runner/src/tasks/controller_plane.rs deleted file mode 100644 index 2981f9b5..00000000 --- a/packages/runner/src/tasks/controller_plane.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::{collections::VecDeque, sync::Arc, time::Instant}; - -use atm0s_sdn_identity::NodeId; -use atm0s_sdn_network::{ - base::{Authorization, HandshakeBuilder, ServiceBuilder}, - controller_plane::{ControllerPlane, Input as ControllerInput, Output as ControllerOutput}, - features::{FeaturesControl, FeaturesEvent}, - ExtIn, ExtOut, LogicControl, LogicEvent, -}; -use atm0s_sdn_router::shadow::ShadowRouterHistory; -use rand::rngs::ThreadRng; -use sans_io_runtime::{bus::BusEvent, Task, TaskInput, TaskOutput}; - -use crate::time::{TimePivot, TimeTicker}; - -pub type ChannelIn = (); -pub type ChannelOut = (); - -pub struct ControllerPlaneCfg { - pub node_id: NodeId, - pub session: u64, - pub tick_ms: u64, - pub auth: Arc, - pub handshake: Arc, - pub history: Arc, - pub services: Vec>>, - #[cfg(feature = "vpn")] - pub vpn_tun_device: Option, -} - -pub type EventIn = LogicControl; -pub type EventOut = LogicEvent; - -pub struct ControllerPlaneTask { - #[allow(unused)] - node_id: NodeId, - controller: ControllerPlane, - queue: VecDeque, ChannelIn, ChannelOut, EventOut>>, - ticker: TimeTicker, - timer: TimePivot, - history: Arc, - #[cfg(feature = "vpn")] - _vpn_tun_device: Option, -} - -impl ControllerPlaneTask { - pub fn build(cfg: ControllerPlaneCfg) -> Self { - Self { - node_id: cfg.node_id, - controller: ControllerPlane::new(cfg.node_id, cfg.session, cfg.services, cfg.auth, cfg.handshake, Box::new(ThreadRng::default())), - queue: VecDeque::from([TaskOutput::Bus(BusEvent::ChannelSubscribe(()))]), - ticker: TimeTicker::build(1000), - timer: TimePivot::build(), - history: cfg.history, - #[cfg(feature = "vpn")] - _vpn_tun_device: cfg.vpn_tun_device, - } - } -} - -impl Task, ExtOut, ChannelIn, ChannelOut, EventIn, EventOut> for ControllerPlaneTask { - /// The type identifier for the task. - const TYPE: u16 = 0; - - fn on_tick<'a>(&mut self, now: Instant) -> Option, ChannelIn, ChannelOut, EventOut>> { - if self.ticker.tick(now) { - self.controller.on_tick(self.timer.timestamp_ms(now)); - self.history.set_ts(self.timer.timestamp_ms(now)); - } - self.pop_output(now) - } - - fn on_event<'a>(&mut self, now: Instant, input: TaskInput<'a, ExtIn, ChannelIn, EventIn>) -> Option, ChannelIn, ChannelOut, EventOut>> { - let now_ms = self.timer.timestamp_ms(now); - match input { - TaskInput::Bus(_, event) => { - self.controller.on_event(now_ms, ControllerInput::Control(event)); - } - TaskInput::Ext(event) => { - self.controller.on_event(now_ms, ControllerInput::Ext(event)); - } - _ => { - panic!("Invalid input type for ControllerPlane") - } - }; - self.pop_output(now) - } - - fn pop_output<'a>(&mut self, now: Instant) -> Option, ChannelIn, ChannelOut, EventOut>> { - let now_ms = self.timer.timestamp_ms(now); - if let Some(output) = self.queue.pop_front() { - return Some(output); - } - let output = self.controller.pop_output(now_ms)?; - match output { - ControllerOutput::Ext(event) => Some(TaskOutput::Ext(event)), - ControllerOutput::ShutdownSuccess => Some(TaskOutput::Destroy), - ControllerOutput::Event(bus) => Some(TaskOutput::Bus(BusEvent::ChannelPublish((), true, bus))), - } - } - - fn shutdown<'a>(&mut self, now: Instant) -> Option, ChannelIn, ChannelOut, EventOut>> { - self.controller.on_event(self.timer.timestamp_ms(now), ControllerInput::ShutdownRequest); - self.pop_output(now) - } -} diff --git a/packages/runner/src/tasks/data_plane.rs b/packages/runner/src/tasks/data_plane.rs deleted file mode 100644 index 5c4aa747..00000000 --- a/packages/runner/src/tasks/data_plane.rs +++ /dev/null @@ -1,214 +0,0 @@ -use std::{ - collections::VecDeque, - fmt::Debug, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, - time::Instant, -}; - -use atm0s_sdn_identity::NodeId; -use atm0s_sdn_network::{ - base::{GenericBuffer, ReadOnlyBuffer, ServiceBuilder}, - data_plane::{DataPlane, Input as DataPlaneInput, NetInput, NetOutput, Output as DataPlaneOutput}, - features::{FeaturesControl, FeaturesEvent}, - ExtOut, LogicControl, LogicEvent, -}; -use atm0s_sdn_router::shadow::ShadowRouterHistory; -use sans_io_runtime::{bus::BusEvent, Buffer, NetIncoming, NetOutgoing, Task, TaskInput, TaskOutput}; - -use crate::time::TimePivot; - -pub(crate) mod history; - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub enum ChannelIn { - Broadcast, - Worker(u16), -} - -pub type ChannelOut = (); - -pub type EventIn = LogicEvent; -pub type EventOut = LogicControl; - -pub struct DataPlaneCfg { - pub history: Arc, - pub worker: u16, - pub node_id: NodeId, - pub port: u16, - pub services: Vec>>, - #[cfg(feature = "vpn")] - pub vpn_tun_fd: Option, -} - -pub struct DataPlaneTask { - #[allow(unused)] - node_id: NodeId, - worker: u16, - data_plane: DataPlane, - backend_udp_slot: usize, - timer: TimePivot, - #[cfg(feature = "vpn")] - backend_tun_slot: usize, - queue: VecDeque, ChannelIn, ChannelOut, EventOut>>, -} - -impl DataPlaneTask { - pub fn build(cfg: DataPlaneCfg) -> Self { - let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), cfg.port)); - let mut queue = VecDeque::from([ - TaskOutput::Net(NetOutgoing::UdpListen { addr, reuse: true }), - TaskOutput::Bus(BusEvent::ChannelSubscribe(ChannelIn::Broadcast)), - TaskOutput::Bus(BusEvent::ChannelSubscribe(ChannelIn::Worker(cfg.worker))), - ]); - #[cfg(feature = "vpn")] - if let Some(fd) = cfg.vpn_tun_fd { - queue.push_back(TaskOutput::Net(NetOutgoing::TunBind { fd })); - } - Self { - node_id: cfg.node_id, - worker: cfg.worker, - data_plane: DataPlane::new(cfg.node_id, cfg.services, cfg.history), - backend_udp_slot: 0, - timer: TimePivot::build(), - #[cfg(feature = "vpn")] - backend_tun_slot: 0, - queue, - } - } - - fn convert_output<'a>(&mut self, _now: Instant, output: DataPlaneOutput<'a, SE, TC>) -> Option, ChannelIn, ChannelOut, EventOut>> { - match output { - DataPlaneOutput::Net(NetOutput::UdpPacket(to, buf)) => Some(TaskOutput::Net(NetOutgoing::UdpPacket { - slot: self.backend_udp_slot, - to, - data: convert_buf1(buf), - })), - #[cfg(feature = "vpn")] - DataPlaneOutput::Net(NetOutput::TunPacket(buf)) => Some(TaskOutput::Net(NetOutgoing::TunPacket { - slot: self.backend_tun_slot, - data: convert_buf1(buf), - })), - #[cfg(not(feature = "vpn"))] - DataPlaneOutput::Net(NetOutput::TunPacket(_)) => None, - DataPlaneOutput::Net(NetOutput::UdpPackets(to, buf)) => Some(TaskOutput::Net(NetOutgoing::UdpPackets { - slot: self.backend_udp_slot, - to, - data: convert_buf1(buf), - })), - DataPlaneOutput::Control(bus) => Some(TaskOutput::Bus(BusEvent::ChannelPublish((), true, bus))), - DataPlaneOutput::ShutdownResponse => { - self.queue.push_back(TaskOutput::Net(NetOutgoing::UdpUnlisten { slot: self.backend_udp_slot })); - self.queue.push_back(TaskOutput::Bus(BusEvent::ChannelUnsubscribe(ChannelIn::Broadcast))); - self.queue.push_back(TaskOutput::Bus(BusEvent::ChannelUnsubscribe(ChannelIn::Worker(self.worker)))); - self.queue.push_back(TaskOutput::Destroy); - self.queue.pop_front() - } - DataPlaneOutput::Ext(ext) => Some(TaskOutput::Ext(ext)), - DataPlaneOutput::Continue => None, - } - } - - fn try_process_output<'a>(&mut self, now: Instant, output: DataPlaneOutput<'a, SE, TC>) -> Option, ChannelIn, ChannelOut, EventOut>> { - let out = self.convert_output(now, output); - if out.is_some() { - return out; - } - self.pop_output_direct(now) - } - - fn pop_output_direct<'a>(&mut self, now: Instant) -> Option, ChannelIn, ChannelOut, EventOut>> { - // self.pop_output_direct(now) - let now_ms = self.timer.timestamp_ms(now); - loop { - let output = self.data_plane.pop_output(now_ms)?; - let out = self.convert_output(now, output); - if out.is_some() { - return out; - } - } - } -} - -impl Task<(), ExtOut, ChannelIn, ChannelOut, EventIn, EventOut> for DataPlaneTask { - /// The type identifier for the task. - const TYPE: u16 = 1; - - fn on_tick<'a>(&mut self, now: Instant) -> Option, ChannelIn, ChannelOut, EventOut>> { - if let Some(out) = self.queue.pop_front() { - return Some(out); - } - - let now_ms = self.timer.timestamp_ms(now); - self.data_plane.on_tick(now_ms); - self.pop_output(now) - } - - fn on_event<'a>(&mut self, now: Instant, input: TaskInput<'a, (), ChannelIn, EventIn>) -> Option, ChannelIn, ChannelOut, EventOut>> { - match input { - TaskInput::Ext(_) => None, - TaskInput::Net(net) => match net { - NetIncoming::UdpListenResult { bind, result } => { - let res = result.expect("Should to bind UDP socket"); - self.backend_udp_slot = res.1; - log::info!("Data plane task bound udp {} to {}", bind, res.0); - None - } - NetIncoming::UdpPacket { slot: _, from, data } => { - let now_ms = self.timer.timestamp_ms(now); - let out = self.data_plane.on_event(now_ms, DataPlaneInput::Net(NetInput::UdpPacket(from, data.into())))?; - self.try_process_output(now, out) - } - #[cfg(feature = "vpn")] - NetIncoming::TunBindResult { result } => { - let res = result.expect("Should to bind TUN device"); - self.backend_tun_slot = res; - log::info!("Data plane task bound tun to {}", res); - None - } - #[cfg(feature = "vpn")] - NetIncoming::TunPacket { slot: _, data } => { - let now_ms = self.timer.timestamp_ms(now); - let out = self.data_plane.on_event(now_ms, DataPlaneInput::Net(NetInput::TunPacket(data.into())))?; - self.try_process_output(now, out) - } - #[cfg(not(feature = "vpn"))] - NetIncoming::TunBindResult { .. } => None, - #[cfg(not(feature = "vpn"))] - NetIncoming::TunPacket { .. } => None, - }, - TaskInput::Bus(_, event) => { - let output = self.data_plane.on_event(self.timer.timestamp_ms(now), DataPlaneInput::Event(event))?; - self.try_process_output(now, output) - } - } - } - - fn pop_output<'a>(&mut self, now: Instant) -> Option, ChannelIn, ChannelOut, EventOut>> { - if let Some(output) = self.queue.pop_front() { - return Some(output); - } - - self.pop_output_direct(now) - } - - fn shutdown<'a>(&mut self, now: Instant) -> Option, ChannelIn, ChannelOut, EventOut>> { - let output = self.data_plane.on_event(self.timer.timestamp_ms(now), DataPlaneInput::ShutdownRequest)?; - self.try_process_output(now, output) - } -} - -fn convert_buf1<'a>(buf1: GenericBuffer<'a>) -> Buffer<'a> { - match buf1.buf { - ReadOnlyBuffer::Ref(buf) => Buffer::Ref(&buf[buf1.range.clone()]), - ReadOnlyBuffer::Vec(mut buf) => { - if buf1.range.start == 0 { - buf.truncate(buf1.range.end); - Buffer::Vec(buf) - } else { - //TODO optimize this case for avoiding copy - Buffer::Vec(buf[buf1.range.clone()].to_vec()) - } - } - } -} diff --git a/packages/runner/src/tasks/event_convert/controller_plane.rs b/packages/runner/src/tasks/event_convert/controller_plane.rs deleted file mode 100644 index 13c61a60..00000000 --- a/packages/runner/src/tasks/event_convert/controller_plane.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::fmt::Debug; - -use atm0s_sdn_network::{ExtIn, ExtOut}; -use sans_io_runtime::{bus::BusEvent, Owner, Task, TaskInput, TaskOutput, WorkerInnerOutput}; - -use crate::tasks::{ - controller_plane::{self, ControllerPlaneTask}, - data_plane, SdnChannel, SdnEvent, SdnExtIn, SdnExtOut, SdnSpawnCfg, -}; - -/// -/// -/// This function will convert the input from SDN into Plane task input. -/// It only accept bus events from the SDN task. -/// -pub fn convert_input<'a, SC, TC, TW>(event: TaskInput<'a, SdnExtIn, SdnChannel, SdnEvent>) -> TaskInput<'a, ExtIn, controller_plane::ChannelIn, controller_plane::EventIn> { - match event { - TaskInput::Bus(_, SdnEvent::ControllerPlane(event)) => TaskInput::Bus((), event), - TaskInput::Ext(ext) => TaskInput::Ext(ext), - _ => panic!("Invalid input type for ControllerPlane"), - } -} - -/// -/// -/// This function will convert the output from the Plane task into the output for the SDN task. -/// It only accept bus events from the Plane task. -/// -pub fn convert_output<'a, SE, TC: Debug, TW: Debug>( - worker: u16, - event: TaskOutput, controller_plane::ChannelIn, controller_plane::ChannelOut, controller_plane::EventOut>, -) -> WorkerInnerOutput<'a, SdnExtOut, SdnChannel, SdnEvent, SdnSpawnCfg> { - match event { - TaskOutput::Ext(ext) => WorkerInnerOutput::Ext(true, ext), - TaskOutput::Bus(BusEvent::ChannelSubscribe(channel)) => WorkerInnerOutput::Task( - Owner::group(worker, ControllerPlaneTask::<(), (), (), ()>::TYPE), - TaskOutput::Bus(BusEvent::ChannelSubscribe(SdnChannel::ControllerPlane(channel))), - ), - TaskOutput::Bus(BusEvent::ChannelUnsubscribe(channel)) => WorkerInnerOutput::Task( - Owner::group(worker, ControllerPlaneTask::<(), (), (), ()>::TYPE), - TaskOutput::Bus(BusEvent::ChannelUnsubscribe(SdnChannel::ControllerPlane(channel))), - ), - TaskOutput::Bus(BusEvent::ChannelPublish(_, safe, event)) => { - let channel = if event.is_broadcast() { - log::debug!("Broadcast to all workers {:?}", event); - SdnChannel::DataPlane(data_plane::ChannelIn::Broadcast) - } else { - SdnChannel::DataPlane(data_plane::ChannelIn::Worker(0)) - }; - - WorkerInnerOutput::Task( - Owner::group(worker, ControllerPlaneTask::<(), (), (), ()>::TYPE), - TaskOutput::Bus(BusEvent::ChannelPublish(channel, safe, SdnEvent::DataPlane(event))), - ) - } - _ => panic!("Invalid output type from ControllerPlane"), - } -} diff --git a/packages/runner/src/tasks/event_convert/data_plane.rs b/packages/runner/src/tasks/event_convert/data_plane.rs deleted file mode 100644 index 7abf38d1..00000000 --- a/packages/runner/src/tasks/event_convert/data_plane.rs +++ /dev/null @@ -1,48 +0,0 @@ -use atm0s_sdn_network::ExtOut; -use sans_io_runtime::{bus::BusEvent, Owner, Task, TaskInput, TaskOutput, WorkerInnerOutput}; - -use crate::tasks::{ - data_plane::{self, DataPlaneTask}, - SdnChannel, SdnEvent, SdnExtIn, SdnExtOut, SdnSpawnCfg, -}; - -/// -/// -/// This function will convert the input from SDN into Plane task input. -/// It only accept bus events from the SDN task. -/// -pub fn convert_input<'a, SC, TC, TW>(event: TaskInput<'a, SdnExtIn, SdnChannel, SdnEvent>) -> TaskInput<'a, (), data_plane::ChannelIn, data_plane::EventIn> { - match event { - TaskInput::Bus(SdnChannel::DataPlane(channel), SdnEvent::DataPlane(event)) => TaskInput::Bus(channel, event), - TaskInput::Net(event) => TaskInput::Net(event), - _ => panic!("Invalid input type for DataPlane"), - } -} - -/// -/// -/// This function will convert the output from the Plane task into the output for the SDN task. -/// It only accept bus events from the Plane task. -/// -pub fn convert_output<'a, SE, TC, TW>( - worker: u16, - event: TaskOutput<'a, ExtOut, data_plane::ChannelIn, data_plane::ChannelOut, data_plane::EventOut>, -) -> WorkerInnerOutput<'a, SdnExtOut, SdnChannel, SdnEvent, SdnSpawnCfg> { - match event { - TaskOutput::Ext(ext) => WorkerInnerOutput::Ext(true, ext), - TaskOutput::Bus(BusEvent::ChannelSubscribe(channel)) => WorkerInnerOutput::Task( - Owner::group(worker, DataPlaneTask::<(), (), (), ()>::TYPE), - TaskOutput::Bus(BusEvent::ChannelSubscribe(SdnChannel::DataPlane(channel))), - ), - TaskOutput::Bus(BusEvent::ChannelUnsubscribe(channel)) => WorkerInnerOutput::Task( - Owner::group(worker, DataPlaneTask::<(), (), (), ()>::TYPE), - TaskOutput::Bus(BusEvent::ChannelUnsubscribe(SdnChannel::DataPlane(channel))), - ), - TaskOutput::Bus(BusEvent::ChannelPublish(_, safe, event)) => WorkerInnerOutput::Task( - Owner::group(worker, DataPlaneTask::<(), (), (), ()>::TYPE), - TaskOutput::Bus(BusEvent::ChannelPublish(SdnChannel::ControllerPlane(()), safe, SdnEvent::ControllerPlane(event))), - ), - TaskOutput::Net(out) => WorkerInnerOutput::Task(Owner::group(worker, DataPlaneTask::<(), (), (), ()>::TYPE), TaskOutput::Net(out)), - TaskOutput::Destroy => WorkerInnerOutput::Task(Owner::group(worker, DataPlaneTask::<(), (), (), ()>::TYPE), TaskOutput::Destroy), - } -} diff --git a/packages/runner/src/tasks/event_convert/mod.rs b/packages/runner/src/tasks/event_convert/mod.rs deleted file mode 100644 index b1899cb8..00000000 --- a/packages/runner/src/tasks/event_convert/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub(crate) mod controller_plane; -pub(crate) mod data_plane; diff --git a/packages/runner/src/tasks/mod.rs b/packages/runner/src/tasks/mod.rs deleted file mode 100644 index d00c8779..00000000 --- a/packages/runner/src/tasks/mod.rs +++ /dev/null @@ -1,247 +0,0 @@ -mod controller_plane; -mod data_plane; -mod event_convert; - -use std::{fmt::Debug, sync::Arc, time::Instant}; - -use atm0s_sdn_identity::NodeId; -use atm0s_sdn_network::{ - base::{Authorization, HandshakeBuilder, ServiceBuilder}, - features::{FeaturesControl, FeaturesEvent}, - ExtIn, ExtOut, -}; -use atm0s_sdn_router::shadow::ShadowRouterHistory; -use sans_io_runtime::{Controller, Task, TaskInput, TaskOutput, TaskSwitcher, WorkerInner, WorkerInnerInput, WorkerInnerOutput}; - -pub use self::data_plane::history::DataWorkerHistory; -use self::{ - controller_plane::{ControllerPlaneCfg, ControllerPlaneTask}, - data_plane::{DataPlaneCfg, DataPlaneTask}, -}; - -pub type SdnController = Controller, SdnExtOut, SdnSpawnCfg, SdnChannel, SdnEvent, 1024>; - -pub type SdnExtIn = ExtIn; -pub type SdnExtOut = ExtOut; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum SdnChannel { - ControllerPlane(controller_plane::ChannelIn), - DataPlane(data_plane::ChannelIn), -} - -#[derive(Debug, Clone)] -pub enum SdnEvent { - ControllerPlane(controller_plane::EventIn), - DataPlane(data_plane::EventIn), -} - -pub struct ControllerCfg { - pub session: u64, - pub auth: Arc, - pub handshake: Arc, - pub tick_ms: u64, - #[cfg(feature = "vpn")] - pub vpn_tun_device: Option, -} - -pub struct SdnInnerCfg { - pub node_id: NodeId, - pub udp_port: u16, - pub controller: Option, - pub services: Vec>>, - pub history: Arc, - #[cfg(feature = "vpn")] - pub vpn_tun_fd: Option, -} - -pub struct SdnSpawnCfg {} - -enum State { - Running, - Shutdowning, - Shutdowned, -} - -pub struct SdnWorkerInner { - worker: u16, - controller: Option>, - data: DataPlaneTask, - switcher: TaskSwitcher, - state: State, -} - -impl SdnWorkerInner { - fn convert_controller_output<'a>( - &mut self, - now: Instant, - event: TaskOutput<'a, ExtOut, controller_plane::ChannelIn, controller_plane::ChannelOut, controller_plane::EventOut>, - ) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { - match event { - TaskOutput::Destroy => { - self.state = State::Shutdowned; - log::info!("Controller plane task destroyed => will destroy data plane task"); - Some(event_convert::data_plane::convert_output(self.worker, self.data.shutdown(now)?)) - } - _ => Some(event_convert::controller_plane::convert_output(self.worker, event)), - } - } -} - -impl WorkerInner, SdnExtOut, SdnChannel, SdnEvent, SdnInnerCfg, SdnSpawnCfg> for SdnWorkerInner { - fn build(worker: u16, cfg: SdnInnerCfg) -> Self { - if let Some(controller) = cfg.controller { - log::info!("Create controller worker"); - Self { - worker, - controller: Some(ControllerPlaneTask::build(ControllerPlaneCfg { - node_id: cfg.node_id, - auth: controller.auth, - handshake: controller.handshake, - session: controller.session, - tick_ms: controller.tick_ms, - services: cfg.services.clone(), - history: cfg.history.clone(), - #[cfg(feature = "vpn")] - vpn_tun_device: controller.vpn_tun_device, - })), - data: DataPlaneTask::build(DataPlaneCfg { - worker, - node_id: cfg.node_id, - port: cfg.udp_port, - services: cfg.services, - history: cfg.history, - #[cfg(feature = "vpn")] - vpn_tun_fd: cfg.vpn_tun_fd, - }), - switcher: TaskSwitcher::new(2), - state: State::Running, - } - } else { - log::info!("Create data only worker"); - Self { - worker, - controller: None, - data: DataPlaneTask::build(DataPlaneCfg { - worker, - node_id: cfg.node_id, - port: cfg.udp_port, - services: cfg.services, - history: cfg.history, - #[cfg(feature = "vpn")] - vpn_tun_fd: cfg.vpn_tun_fd, - }), - switcher: TaskSwitcher::new(2), - state: State::Running, - } - } - } - - fn worker_index(&self) -> u16 { - self.worker - } - - fn tasks(&self) -> usize { - match self.state { - State::Running | State::Shutdowning => { - 1 + if self.controller.is_some() { - 1 - } else { - 0 - } - } - State::Shutdowned => 0, - } - } - - fn spawn(&mut self, _now: Instant, _cfg: SdnSpawnCfg) { - todo!("Spawn not implemented") - } - - fn on_tick<'a>(&mut self, now: Instant) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { - let s = &mut self.switcher; - loop { - match s.looper_current(now)? as u16 { - ControllerPlaneTask::<(), (), (), ()>::TYPE => { - if let Some(out) = s.looper_process(self.controller.as_mut().map(|c| c.on_tick(now)).flatten()) { - return self.convert_controller_output(now, out); - } - } - DataPlaneTask::<(), (), (), ()>::TYPE => { - if let Some(out) = s.looper_process(self.data.on_tick(now)) { - return Some(event_convert::data_plane::convert_output(self.worker, out)); - } - } - _ => panic!("unknown task type"), - } - } - } - - fn on_event<'a>( - &mut self, - now: Instant, - event: WorkerInnerInput<'a, SdnExtIn, SdnChannel, SdnEvent>, - ) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { - match event { - WorkerInnerInput::Task(owner, event) => match owner.group_id()? { - ControllerPlaneTask::<(), (), (), ()>::TYPE => { - let event = event_convert::controller_plane::convert_input(event); - let out = self.controller.as_mut().map(|c| c.on_event(now, event)).flatten()?; - self.switcher.queue_flag_task(ControllerPlaneTask::<(), (), (), ()>::TYPE as usize); - self.convert_controller_output(now, out) - } - DataPlaneTask::<(), (), (), ()>::TYPE => { - let event = event_convert::data_plane::convert_input(event); - let out = self.data.on_event(now, event)?; - self.switcher.queue_flag_task(DataPlaneTask::<(), (), (), ()>::TYPE as usize); - Some(event_convert::data_plane::convert_output(self.worker, out)) - } - _ => panic!("unknown task type"), - }, - WorkerInnerInput::Ext(ext) => { - let out = self.controller.as_mut().map(|c| c.on_event(now, TaskInput::Ext(ext))).flatten()?; - self.switcher.queue_flag_task(ControllerPlaneTask::<(), (), (), ()>::TYPE as usize); - Some(event_convert::controller_plane::convert_output(self.worker, out)) - } - } - } - - fn pop_output<'a>(&mut self, now: Instant) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { - while let Some(current) = self.switcher.queue_current() { - match current as u16 { - ControllerPlaneTask::<(), (), (), ()>::TYPE => { - let out = self.controller.as_mut().map(|c| c.pop_output(now)).flatten(); - if let Some(out) = self.switcher.queue_process(out) { - let out = self.convert_controller_output(now, out); - if out.is_some() { - return out; - } - } - } - DataPlaneTask::<(), (), (), ()>::TYPE => { - let out = self.data.pop_output(now); - if let Some(out) = self.switcher.queue_process(out) { - return Some(event_convert::data_plane::convert_output(self.worker, out)); - } - } - _ => panic!("unknown task type"), - } - } - None - } - - fn shutdown<'a>(&mut self, now: Instant) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { - if !matches!(self.state, State::Running) { - return None; - } - - if let Some(controller) = &mut self.controller { - self.state = State::Shutdowning; - let out = controller.shutdown(now)?; - self.convert_controller_output(now, out) - } else { - self.state = State::Shutdowned; - Some(event_convert::data_plane::convert_output(self.worker, self.data.shutdown(now)?)) - } - } -} diff --git a/packages/runner/src/worker_inner.rs b/packages/runner/src/worker_inner.rs new file mode 100644 index 00000000..dc19a5c9 --- /dev/null +++ b/packages/runner/src/worker_inner.rs @@ -0,0 +1,278 @@ +use std::{ + collections::VecDeque, + fmt::Debug, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::Arc, + time::Instant, +}; + +use atm0s_sdn_identity::NodeId; +use atm0s_sdn_network::{ + base::{Authorization, HandshakeBuilder, ServiceBuilder}, + controller_plane::ControllerPlaneCfg, + data_plane::{DataPlaneCfg, NetInput, NetOutput}, + features::{FeaturesControl, FeaturesEvent}, + worker::{SdnWorker, SdnWorkerBusEvent, SdnWorkerCfg, SdnWorkerInput, SdnWorkerOutput}, + ExtIn, ExtOut, +}; +use atm0s_sdn_router::shadow::ShadowRouterHistory; +use rand::rngs::ThreadRng; +use sans_io_runtime::{ + backend::{BackendIncoming, BackendOutgoing}, + BusChannelControl, BusControl, BusEvent, Controller, WorkerInner, WorkerInnerInput, WorkerInnerOutput, +}; + +use crate::time::TimePivot; + +pub type SdnController = Controller, SdnExtOut, SdnSpawnCfg, SdnChannel, SdnEvent, 1024>; + +pub type SdnExtIn = ExtIn; +pub type SdnExtOut = ExtOut; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SdnOwner; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SdnChannel { + Controller, + Worker(u16), +} + +pub type SdnEvent = SdnWorkerBusEvent; + +pub struct ControllerCfg { + pub session: u64, + pub auth: Arc, + pub handshake: Arc, + pub tick_ms: u64, + #[cfg(feature = "vpn")] + pub vpn_tun_device: Option, +} + +pub struct SdnInnerCfg { + pub node_id: NodeId, + pub tick_ms: u64, + pub udp_port: u16, + pub controller: Option, + pub services: Vec>>, + pub history: Arc, + #[cfg(feature = "vpn")] + pub vpn_tun_fd: Option, +} + +pub type SdnSpawnCfg = (); + +enum State { + Running, + Shutdowning, + Shutdowned, +} + +pub struct SdnWorkerInner { + worker: u16, + worker_inner: SdnWorker, + state: State, + timer: TimePivot, + #[cfg(feature = "vpn")] + _vpn_tun_device: Option, + udp_backend_slot: Option, + #[cfg(feature = "vpn")] + tun_backend_slot: Option, + queue: VecDeque, SdnChannel, SdnEvent, SdnSpawnCfg>>, +} + +impl SdnWorkerInner { + fn convert_output<'a>( + &mut self, + now_ms: u64, + event: SdnWorkerOutput<'a, SC, SE, TC, TW>, + ) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { + match event { + SdnWorkerOutput::Ext(ext) => Some(WorkerInnerOutput::Ext(true, ext)), + SdnWorkerOutput::ExtWorker(_) => { + panic!("should not have ExtWorker with standalone node") + } + SdnWorkerOutput::Net(net) => { + let out = match net { + NetOutput::UdpPacket(dest, data) => BackendOutgoing::UdpPacket { + slot: self.udp_backend_slot.expect("Should have backend slot"), + to: dest, + data, + }, + NetOutput::UdpPackets(dests, data) => BackendOutgoing::UdpPackets { + slot: self.udp_backend_slot.expect("Should have backend slot"), + to: dests, + data, + }, + #[cfg(feature = "vpn")] + NetOutput::TunPacket(data) => BackendOutgoing::TunPacket { + slot: self.tun_backend_slot.expect("should have tun"), + data, + }, + }; + Some(WorkerInnerOutput::Net(SdnOwner, out)) + } + SdnWorkerOutput::Bus(event) => match &event { + SdnWorkerBusEvent::Control(..) => Some(WorkerInnerOutput::Bus(BusControl::Channel(SdnOwner, BusChannelControl::Publish(SdnChannel::Controller, true, event)))), + SdnWorkerBusEvent::Workers(..) => Some(WorkerInnerOutput::Bus(BusControl::Broadcast(true, event))), + SdnWorkerBusEvent::Worker(worker, _msg) => Some(WorkerInnerOutput::Bus(BusControl::Channel( + SdnOwner, + BusChannelControl::Publish(SdnChannel::Worker(*worker), true, event), + ))), + }, + SdnWorkerOutput::ShutdownResponse => { + self.state = State::Shutdowned; + Some(WorkerInnerOutput::Destroy(SdnOwner)) + } + SdnWorkerOutput::Continue => { + //we need to continue pop for continue gather output + let out = self.worker_inner.pop_output(now_ms)?; + self.convert_output(now_ms, out) + } + } + } +} + +impl WorkerInner, SdnExtOut, SdnChannel, SdnEvent, SdnInnerCfg, SdnSpawnCfg> + for SdnWorkerInner +{ + fn build(worker: u16, cfg: SdnInnerCfg) -> Self { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), cfg.udp_port)); + let mut queue = VecDeque::from([ + WorkerInnerOutput::Bus(BusControl::Channel(SdnOwner, BusChannelControl::Subscribe(SdnChannel::Worker(worker)))), + WorkerInnerOutput::Net(SdnOwner, BackendOutgoing::UdpListen { addr, reuse: true }), + ]); + #[cfg(feature = "vpn")] + if let Some(fd) = cfg.vpn_tun_fd { + queue.push_back(WorkerInnerOutput::Net(SdnOwner, BackendOutgoing::TunBind { fd })); + } + if let Some(controller) = cfg.controller { + queue.push_back(WorkerInnerOutput::Bus(BusControl::Channel(SdnOwner, BusChannelControl::Subscribe(SdnChannel::Controller)))); + log::info!("Create controller worker"); + Self { + worker, + worker_inner: SdnWorker::new(SdnWorkerCfg { + node_id: cfg.node_id, + tick_ms: cfg.tick_ms, + controller: Some(ControllerPlaneCfg { + authorization: controller.auth, + handshake_builder: controller.handshake, + session: controller.session, + random: Box::new(ThreadRng::default()), + services: cfg.services.clone(), + }), + data: DataPlaneCfg { + worker_id: worker, + services: cfg.services, + history: cfg.history, + }, + }), + timer: TimePivot::build(), + #[cfg(feature = "vpn")] + _vpn_tun_device: controller.vpn_tun_device, + state: State::Running, + queue, + udp_backend_slot: None, + #[cfg(feature = "vpn")] + tun_backend_slot: None, + } + } else { + log::info!("Create data only worker"); + Self { + worker, + worker_inner: SdnWorker::new(SdnWorkerCfg { + node_id: cfg.node_id, + tick_ms: cfg.tick_ms, + controller: None, + data: DataPlaneCfg { + worker_id: worker, + services: cfg.services, + history: cfg.history, + }, + }), + timer: TimePivot::build(), + #[cfg(feature = "vpn")] + _vpn_tun_device: None, + state: State::Running, + queue, + udp_backend_slot: None, + #[cfg(feature = "vpn")] + tun_backend_slot: None, + } + } + } + + fn worker_index(&self) -> u16 { + self.worker + } + + fn tasks(&self) -> usize { + self.worker_inner.tasks() + } + + fn spawn(&mut self, _now: Instant, _cfg: SdnSpawnCfg) { + todo!("Spawn not implemented") + } + + fn on_tick<'a>(&mut self, now: Instant) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { + if let Some(e) = self.queue.pop_front() { + return Some(e); + } + let now_ms = self.timer.timestamp_ms(now); + let out = self.worker_inner.on_tick(now_ms)?; + self.convert_output(now_ms, out) + } + + fn on_event<'a>( + &mut self, + now: Instant, + event: WorkerInnerInput<'a, SdnOwner, SdnExtIn, SdnChannel, SdnEvent>, + ) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { + let now_ms = self.timer.timestamp_ms(now); + let out = match event { + WorkerInnerInput::Net(_, event) => match event { + BackendIncoming::UdpListenResult { bind: _, result } => { + self.udp_backend_slot = Some(result.expect("Should have slot").1); + None + } + BackendIncoming::UdpPacket { slot: _, from, data } => self.worker_inner.on_event(now_ms, SdnWorkerInput::Net(NetInput::UdpPacket(from, data))), + #[cfg(feature = "vpn")] + BackendIncoming::TunBindResult { result } => { + self.tun_backend_slot = Some(result.expect("Should have slot")); + None + } + #[cfg(feature = "vpn")] + BackendIncoming::TunPacket { slot: _, data } => self.worker_inner.on_event(now_ms, SdnWorkerInput::Net(NetInput::TunPacket(data))), + }, + WorkerInnerInput::Bus(event) => match event { + BusEvent::Broadcast(_from_worker, msg) => self.worker_inner.on_event(now_ms, SdnWorkerInput::Bus(msg)), + BusEvent::Channel(_, _, msg) => self.worker_inner.on_event(now_ms, SdnWorkerInput::Bus(msg)), + }, + WorkerInnerInput::Ext(ext) => { + log::info!("on ext event"); + self.worker_inner.on_event(now_ms, SdnWorkerInput::Ext(ext)) + } + }; + self.convert_output(now_ms, out?) + } + + fn pop_output<'a>(&mut self, now: Instant) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { + if let Some(e) = self.queue.pop_front() { + return Some(e); + } + let now_ms = self.timer.timestamp_ms(now); + let out = self.worker_inner.pop_output(now_ms)?; + self.convert_output(now_ms, out) + } + + fn shutdown<'a>(&mut self, now: Instant) -> Option, SdnChannel, SdnEvent, SdnSpawnCfg>> { + if !matches!(self.state, State::Running) { + return None; + } + + let now_ms = self.timer.timestamp_ms(now); + self.state = State::Shutdowning; + let out = self.worker_inner.on_event(now_ms, SdnWorkerInput::ShutdownRequest)?; + self.convert_output(now_ms, out) + } +} diff --git a/packages/runner/tests/feature_dht_kv.rs b/packages/runner/tests/feature_dht_kv.rs new file mode 100644 index 00000000..3a12f936 --- /dev/null +++ b/packages/runner/tests/feature_dht_kv.rs @@ -0,0 +1,112 @@ +use std::time::Duration; + +use atm0s_sdn::{ + features::{ + dht_kv::{self, MapControl, MapEvent}, + FeaturesControl, FeaturesEvent, + }, + secure::StaticKeyAuthorization, + services::visualization, + NodeAddr, NodeId, SdnBuilder, SdnController, SdnControllerUtils, SdnExtOut, SdnOwner, +}; +use sans_io_runtime::backend::PollingBackend; + +type SC = visualization::Control; +type SE = visualization::Event; +type TC = (); +type TW = (); + +fn process(nodes: &mut [&mut SdnController], timeout_ms: u64) { + let mut count = 0; + while count < timeout_ms / 10 { + std::thread::sleep(Duration::from_millis(10)); + count += 1; + for node in nodes.into_iter() { + if node.process().is_none() { + panic!("Node is shutdown"); + } + } + } +} + +fn expect_event(node: &mut SdnController, expected: dht_kv::Event) { + match node.pop_event() { + Some(SdnExtOut::FeaturesEvent(FeaturesEvent::DhtKv(event))) => { + assert_eq!(event, expected); + return; + } + Some(event) => { + panic!("Unexpected event: {:?}", event) + } + None => { + panic!("No event") + } + } +} + +fn build_node(node_id: NodeId, udp_port: u16) -> (SdnController, NodeAddr) { + let mut builder = SdnBuilder::::new(node_id, udp_port, vec![]); + builder.set_authorization(StaticKeyAuthorization::new("password-here")); + let node_addr = builder.node_addr(); + let node = builder.build::>(2); + (node, node_addr) +} + +#[test] +fn test_single_node() { + let (mut node, _node_addr) = build_node(1, 10000); + std::thread::sleep(Duration::from_millis(100)); + node.feature_control(FeaturesControl::DhtKv(dht_kv::Control::MapCmd(1000.into(), MapControl::Sub))); + process(&mut [&mut node], 100); + expect_event(&mut node, dht_kv::Event::MapEvent(1000.into(), MapEvent::OnRelaySelected(1))); + + node.feature_control(FeaturesControl::DhtKv(dht_kv::Control::MapCmd(1000.into(), MapControl::Set(2000.into(), vec![1, 2, 3])))); + process(&mut [&mut node], 100); + expect_event(&mut node, dht_kv::Event::MapEvent(1000.into(), MapEvent::OnSet(2000.into(), 1, vec![1, 2, 3]))); +} + +#[test] +fn test_two_nodes() { + let node1_id = 1; + let node2_id = 2; + let (mut node1, node_addr1) = build_node(node1_id, 11000); + let (mut node2, _node_addr2) = build_node(node2_id, 11001); + + node2.connect_to(node_addr1); + + process(&mut [&mut node1, &mut node2], 100); + log::info!("sending map cmd Sub"); + node1.feature_control(FeaturesControl::DhtKv(dht_kv::Control::MapCmd(1000.into(), MapControl::Sub))); + process(&mut [&mut node1, &mut node2], 100); + expect_event(&mut node1, dht_kv::Event::MapEvent(1000.into(), MapEvent::OnRelaySelected(node1_id))); + + node2.feature_control(FeaturesControl::DhtKv(dht_kv::Control::MapCmd(1000.into(), MapControl::Set(2000.into(), vec![1, 2, 3])))); + process(&mut [&mut node1, &mut node2], 100); + + expect_event(&mut node1, dht_kv::Event::MapEvent(1000.into(), MapEvent::OnSet(2000.into(), node2_id, vec![1, 2, 3]))); +} + +#[test] +fn test_three_nodes() { + let node1_id = 1; + let node2_id = 2; + let node3_id = 3; + let (mut node1, node_addr1) = build_node(node1_id, 12000); + let (mut node2, node_addr2) = build_node(node2_id, 12001); + let (mut node3, _node_addr3) = build_node(node3_id, 12002); + + node2.connect_to(node_addr1); + node3.connect_to(node_addr2); + + process(&mut [&mut node1, &mut node2, &mut node3], 100); + log::info!("sending map cmd Sub"); + node2.feature_control(FeaturesControl::DhtKv(dht_kv::Control::MapCmd(1000.into(), MapControl::Sub))); + process(&mut [&mut node1, &mut node2, &mut node3], 100); + + expect_event(&mut node2, dht_kv::Event::MapEvent(1000.into(), MapEvent::OnRelaySelected(node1_id))); + + node3.feature_control(FeaturesControl::DhtKv(dht_kv::Control::MapCmd(1000.into(), MapControl::Set(2000.into(), vec![1, 2, 3])))); + process(&mut [&mut node1, &mut node2, &mut node3], 100); + + expect_event(&mut node2, dht_kv::Event::MapEvent(1000.into(), MapEvent::OnSet(2000.into(), node3_id, vec![1, 2, 3]))); +}