diff --git a/src/channels.rs b/src/channels.rs index 7691b37..89f31a1 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -1,5 +1,9 @@ +#[cfg(not(target_arch = "wasm32"))] +use bevy::tasks::Task; use bevy::tasks::TaskPool; use futures_timer::Delay; +#[cfg(not(target_arch = "wasm32"))] +use std::sync::Mutex; use std::{future::Future, ops::Deref, pin::Pin, sync::Arc, time::Duration}; use turbulence::{buffer::BufferPool, runtime::Runtime}; @@ -20,10 +24,20 @@ pub struct TaskPoolRuntime(Arc); pub struct TaskPoolRuntimeInner { pool: TaskPool, + #[cfg(not(target_arch = "wasm32"))] + tasks: Mutex>>, // FIXME: cleanup finished } impl TaskPoolRuntime { pub fn new(pool: TaskPool) -> Self { + #[cfg(not(target_arch = "wasm32"))] + { + TaskPoolRuntime(Arc::new(TaskPoolRuntimeInner { + pool, + tasks: Mutex::new(Vec::new()), + })) + } + #[cfg(target_arch = "wasm32")] TaskPoolRuntime(Arc::new(TaskPoolRuntimeInner { pool })) } } @@ -41,7 +55,13 @@ impl Runtime for TaskPoolRuntime { type Sleep = Pin + Send>>; fn spawn + Send + 'static>(&self, f: F) { - self.pool.spawn(Box::pin(f)).detach(); + #[cfg(not(target_arch = "wasm32"))] + self.tasks + .lock() + .unwrap() + .push(self.pool.spawn(Box::pin(f))); + #[cfg(target_arch = "wasm32")] + self.pool.spawn(Box::pin(f)); } fn now(&self) -> Self::Instant { diff --git a/src/lib.rs b/src/lib.rs index 3033eb4..2d8454a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,11 @@ +#[cfg(not(target_arch = "wasm32"))] +use bevy::tasks::Task; use bevy::{ app::{App, CoreStage, Events, Plugin}, core::FixedTimestep, prelude::*, tasks::{IoTaskPool, TaskPool}, }; - #[cfg(not(target_arch = "wasm32"))] use crossbeam_channel::{unbounded, SendError as CrossbeamSendError, Sender}; #[cfg(not(target_arch = "wasm32"))] @@ -114,6 +115,8 @@ pub struct NetworkResource { #[cfg(not(target_arch = "wasm32"))] server_channels: Arc>, + #[cfg(not(target_arch = "wasm32"))] + listeners: Vec>, runtime: TaskPoolRuntime, packet_pool: MuxPacketPool>, @@ -200,6 +203,8 @@ impl NetworkResource { pending_connections: Arc::new(Mutex::new(Vec::new())), #[cfg(not(target_arch = "wasm32"))] server_channels: Arc::new(RwLock::new(HashMap::new())), + #[cfg(not(target_arch = "wasm32"))] + listeners: Vec::new(), runtime, packet_pool, channels_builder_fn: None, @@ -244,79 +249,76 @@ impl NetworkResource { let pending_connections = self.pending_connections.clone(); let task_pool = self.task_pool.clone(); - self.task_pool - .spawn(async move { - loop { - match server_socket.receive().await { - Ok(packet) => { - let address = packet.address(); - let message = String::from_utf8_lossy(packet.payload()); - debug!( - "Server recv <- {}:{}: {}", - address, - packet.payload().len(), - message - ); - - let needs_new_channel = match server_channels - .read() - .expect("server channels lock is poisoned") - .get(&address) - .map(|channel| { - channel.send(Ok(Packet::copy_from_slice(packet.payload()))) - }) { - Some(Ok(())) => false, - Some(Err(CrossbeamSendError(_packet))) => { - error!("Server can't send to channel, recreating"); - // If we can't send to a channel, it's disconnected. - // We need to re-create the channel and re-try sending the message. - true - } - // This is a new connection, so we need to create a channel. - None => true, - }; - - if !needs_new_channel { - continue; + self.listeners.push(self.task_pool.spawn(async move { + loop { + match server_socket.receive().await { + Ok(packet) => { + let address = packet.address(); + let message = String::from_utf8_lossy(packet.payload()); + debug!( + "Server recv <- {}:{}: {}", + address, + packet.payload().len(), + message + ); + + let needs_new_channel = match server_channels + .read() + .expect("server channels lock is poisoned") + .get(&address) + .map(|channel| { + channel.send(Ok(Packet::copy_from_slice(packet.payload()))) + }) { + Some(Ok(())) => false, + Some(Err(CrossbeamSendError(_packet))) => { + error!("Server can't send to channel, recreating"); + // If we can't send to a channel, it's disconnected. + // We need to re-create the channel and re-try sending the message. + true } + // This is a new connection, so we need to create a channel. + None => true, + }; - // We try to do a write lock only in case when a channel doesn't exist or - // has to be re-created. Trying to acquire a channel even for new - // connections is kind of a positive prediction to avoid doing a write - // lock. - let mut server_channels = server_channels - .write() - .expect("server channels lock is poisoned"); - let (packet_tx, packet_rx) = - unbounded::>(); - match packet_tx.send(Ok(Packet::copy_from_slice(packet.payload()))) { - Ok(()) => { - // It makes sense to store the channel only if it's healthy. - pending_connections.lock().unwrap().push(Box::new( - transport::ServerConnection::new( - task_pool.clone(), - packet_rx, - server_socket.get_sender(), - address, - ), - )); - server_channels.insert(address, packet_tx); - } - Err(error) => { - // This branch is unlikely to get called the second time (after - // re-creating a channel), but if for some strange reason it does, - // we'll just lose the message this time. - error!("Server Send Error (retry): {}", error); - } - } + if !needs_new_channel { + continue; } - Err(error) => { - error!("Server Receive Error: {}", error); + + // We try to do a write lock only in case when a channel doesn't exist or + // has to be re-created. Trying to acquire a channel even for new + // connections is kind of a positive prediction to avoid doing a write + // lock. + let mut server_channels = server_channels + .write() + .expect("server channels lock is poisoned"); + let (packet_tx, packet_rx) = unbounded::>(); + match packet_tx.send(Ok(Packet::copy_from_slice(packet.payload()))) { + Ok(()) => { + // It makes sense to store the channel only if it's healthy. + pending_connections.lock().unwrap().push(Box::new( + transport::ServerConnection::new( + task_pool.clone(), + packet_rx, + server_socket.get_sender(), + address, + ), + )); + server_channels.insert(address, packet_tx); + } + Err(error) => { + // This branch is unlikely to get called the second time (after + // re-creating a channel), but if for some strange reason it does, + // we'll just lose the message this time. + error!("Server Send Error (retry): {}", error); + } } } + Err(error) => { + error!("Server Receive Error: {}", error); + } } - }) - .detach(); + } + })); } pub fn connect(&mut self, socket_address: SocketAddr) { diff --git a/src/transport.rs b/src/transport.rs index 9289fb8..1965c71 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,3 +1,5 @@ +#[cfg(not(target_arch = "wasm32"))] +use bevy::tasks::Task; use bevy::{prelude::error, tasks::TaskPool}; use bytes::Bytes; use instant::{Duration, Instant}; @@ -117,6 +119,8 @@ pub struct ServerConnection { channels: Option, channels_rx: Option>, + #[cfg(not(target_arch = "wasm32"))] + channels_task: Option>, } #[cfg(not(target_arch = "wasm32"))] @@ -135,6 +139,7 @@ impl ServerConnection { stats: Arc::new(RwLock::new(PacketStats::default())), channels: None, channels_rx: None, + channels_task: None, } } } @@ -209,21 +214,20 @@ impl Connection for ServerConnection { let mut sender = self.sender.take().unwrap(); let client_address = self.client_address; let stats = self.stats.clone(); - self.task_pool - .spawn(async move { - loop { - let packet = channels_tx.next().await.unwrap(); - stats - .write() - .expect("stats lock poisoned") - .add_tx(packet.len()); - sender - .send(ServerPacket::new(client_address, (*packet).into())) - .await - .unwrap(); - } - }) - .detach(); + + self.channels_task = Some(self.task_pool.spawn(async move { + loop { + let packet = channels_tx.next().await.unwrap(); + stats + .write() + .expect("stats lock poisoned") + .add_tx(packet.len()); + sender + .send(ServerPacket::new(client_address, (*packet).into())) + .await + .unwrap(); + } + })); } fn channels(&mut self) -> Option<&mut MessageChannels> { @@ -244,6 +248,8 @@ pub struct ClientConnection { channels: Option, channels_rx: Option>, + #[cfg(not(target_arch = "wasm32"))] + channels_task: Option>, } impl ClientConnection { @@ -259,6 +265,8 @@ impl ClientConnection { stats: Arc::new(RwLock::new(PacketStats::default())), channels: None, channels_rx: None, + #[cfg(not(target_arch = "wasm32"))] + channels_task: None, } } } @@ -321,25 +329,31 @@ impl Connection for ClientConnection { let mut sender = self.sender.take().unwrap(); let stats = self.stats.clone(); - self.task_pool - .spawn(async move { - loop { - match channels_tx.next().await { - Some(packet) => { - stats - .write() - .expect("stats lock poisoned") - .add_tx(packet.len()); - sender.send(ClientPacket::new((*packet).into())).unwrap(); - } - None => { - error!("Channel stream Disconnected"); - return; // exit task - } + + let closure = async move { + loop { + match channels_tx.next().await { + Some(packet) => { + stats + .write() + .expect("stats lock poisoned") + .add_tx(packet.len()); + sender.send(ClientPacket::new((*packet).into())).unwrap(); + } + None => { + error!("Channel stream Disconnected"); + return; // exit task } } - }) - .detach(); + } + }; + + #[cfg(not(target_arch = "wasm32"))] + { + self.channels_task = Some(self.task_pool.spawn(closure)); + } + #[cfg(target_arch = "wasm32")] + self.task_pool.spawn(closure); } fn channels(&mut self) -> Option<&mut MessageChannels> {