diff --git a/examples/angular-todomvc/src/main.rs b/examples/angular-todomvc/src/main.rs index 87880569..fa9be158 100644 --- a/examples/angular-todomvc/src/main.rs +++ b/examples/angular-todomvc/src/main.rs @@ -4,7 +4,7 @@ use axum::Server; use serde::{Deserialize, Serialize}; use socketioxide::{ - extract::{Data, SocketRef}, + extract::{Data, SocketRef, State}, SocketIo, }; use tower::ServiceBuilder; @@ -12,8 +12,6 @@ use tower_http::{cors::CorsLayer, services::ServeDir}; use tracing::{error, info}; use tracing_subscriber::FmtSubscriber; -static TODOS: Mutex> = Mutex::new(vec![]); - #[derive(Debug, Clone, Serialize, Deserialize)] struct Todo { completed: bool, @@ -21,6 +19,8 @@ struct Todo { title: String, } +type Todos = Mutex>; + #[tokio::main] async fn main() -> Result<(), Box> { let subscriber = FmtSubscriber::new(); @@ -29,12 +29,13 @@ async fn main() -> Result<(), Box> { info!("Starting server"); - let (layer, io) = SocketIo::new_layer(); + let todos: Todos = Mutex::new(vec![]); + let (layer, io) = SocketIo::builder().with_state(todos).build_layer(); - io.ns("/", |s: SocketRef| { + io.ns("/", |s: SocketRef, todos: State| { info!("New connection: {}", s.id); - let todos = TODOS.lock().unwrap().clone(); + let todos = todos.lock().unwrap().clone(); // Because variadic args are not supported, array arguments are flattened. // Therefore to send a json array (required for the todomvc app) we need to wrap it in another array. @@ -42,10 +43,10 @@ async fn main() -> Result<(), Box> { s.on( "update-store", - |s: SocketRef, Data::>(new_todos)| { + |s: SocketRef, Data::>(new_todos), todos: State| { info!("Received update-store event: {:?}", new_todos); - let mut todos = TODOS.lock().unwrap(); + let mut todos = todos.lock().unwrap(); todos.clear(); todos.extend_from_slice(&new_todos); diff --git a/examples/basic-crud-application/src/handlers/todo.rs b/examples/basic-crud-application/src/handlers/todo.rs index 2e9d212f..b7f6a5a2 100644 --- a/examples/basic-crud-application/src/handlers/todo.rs +++ b/examples/basic-crud-application/src/handlers/todo.rs @@ -1,10 +1,7 @@ -use std::{ - collections::HashMap, - sync::{OnceLock, RwLock}, -}; +use std::{collections::HashMap, sync::RwLock}; use serde::{Deserialize, Serialize}; -use socketioxide::extract::{AckSender, Data, SocketRef}; +use socketioxide::extract::{AckSender, Data, SocketRef, State}; use tracing::info; use uuid::Uuid; @@ -24,16 +21,30 @@ pub struct PartialTodo { title: String, } -static TODOS: OnceLock>> = OnceLock::new(); -fn get_store() -> &'static RwLock> { - TODOS.get_or_init(|| RwLock::new(HashMap::new())) +#[derive(Default)] +pub struct Todos(pub RwLock>); +impl Todos { + pub fn insert(&self, id: Uuid, todo: Todo) { + self.0.write().unwrap().insert(id, todo); + } + + pub fn get(&self, id: &Uuid) -> Option { + self.0.read().unwrap().get(id).cloned() + } + + pub fn remove(&self, id: &Uuid) -> Option { + self.0.write().unwrap().remove(id) + } + pub fn values(&self) -> Vec { + self.0.read().unwrap().values().cloned().collect() + } } -pub fn create(s: SocketRef, Data(data): Data, ack: AckSender) { +pub fn create(s: SocketRef, Data(data): Data, ack: AckSender, todos: State) { let id = Uuid::new_v4(); let todo = Todo { id, inner: data }; - get_store().write().unwrap().insert(id, todo.clone()); + todos.insert(id, todo.clone()); let res: Response<_> = id.into(); ack.send(res).ok(); @@ -41,25 +52,27 @@ pub fn create(s: SocketRef, Data(data): Data, ack: AckSender) { s.broadcast().emit("todo:created", todo).ok(); } -pub async fn read(Data(id): Data, ack: AckSender) { - let todos = get_store().read().unwrap(); - +pub async fn read(Data(id): Data, ack: AckSender, todos: State) { let todo = todos.get(&id).ok_or(Error::NotFound); ack.send(todo).ok(); } -pub async fn update(s: SocketRef, Data(data): Data, ack: AckSender) { - let mut todos = get_store().write().unwrap(); - let res = todos.get_mut(&data.id).ok_or(Error::NotFound).map(|todo| { - todo.inner = data.inner.clone(); - s.broadcast().emit("todo:updated", data).ok(); - }); +pub async fn update(s: SocketRef, Data(data): Data, ack: AckSender, todos: State) { + let res = todos + .0 + .write() + .unwrap() + .get_mut(&data.id) + .ok_or(Error::NotFound) + .map(|todo| { + todo.inner = data.inner.clone(); + s.broadcast().emit("todo:updated", data).ok(); + }); ack.send(res).ok(); } -pub async fn delete(s: SocketRef, Data(id): Data, ack: AckSender) { - let mut todos = get_store().write().unwrap(); +pub async fn delete(s: SocketRef, Data(id): Data, ack: AckSender, todos: State) { let res = todos.remove(&id).ok_or(Error::NotFound).map(|_| { s.broadcast().emit("todo:deleted", id).ok(); }); @@ -67,9 +80,8 @@ pub async fn delete(s: SocketRef, Data(id): Data, ack: AckSender) { ack.send(res).ok(); } -pub async fn list(ack: AckSender) { - let todos = get_store().read().unwrap(); - let res: Response<_> = todos.values().cloned().collect::>().into(); +pub async fn list(ack: AckSender, todos: State) { + let res: Response<_> = todos.values().into(); info!("Sending todos: {:?}", res); ack.send(res).ok(); } diff --git a/examples/basic-crud-application/src/main.rs b/examples/basic-crud-application/src/main.rs index 41075bd9..7534b20b 100644 --- a/examples/basic-crud-application/src/main.rs +++ b/examples/basic-crud-application/src/main.rs @@ -6,6 +6,8 @@ use tower_http::{cors::CorsLayer, services::ServeDir}; use tracing::{error, info}; use tracing_subscriber::FmtSubscriber; +use crate::handlers::todo::Todos; + mod handlers; #[tokio::main] @@ -16,7 +18,9 @@ async fn main() -> Result<(), Box> { info!("Starting server"); - let (layer, io) = SocketIo::new_layer(); + let (layer, io) = SocketIo::builder() + .with_state(Todos::default()) + .build_layer(); io.ns("/", |s: SocketRef| { s.on("todo:create", handlers::todo::create); diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 333481c1..4478e8e6 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -10,6 +10,7 @@ use engineioxide::sid::Sid; use tokio::sync::oneshot; use crate::adapter::Adapter; +use crate::extract::StateCell; use crate::handler::ConnectHandler; use crate::ProtocolVersion; use crate::{ @@ -23,13 +24,17 @@ use crate::{ pub struct Client { pub(crate) config: Arc, ns: RwLock, Arc>>>, + + state: StateCell, } impl Client { - pub fn new(config: Arc) -> Self { + pub fn new(config: Arc, state: StateCell) -> Self { Self { config, ns: RwLock::new(HashMap::new()), + + state, } } @@ -45,7 +50,7 @@ impl Client { let sid = esocket.id; if let Some(ns) = self.get_ns(ns_path) { - ns.connect(sid, esocket.clone(), auth, self.config.clone())?; + ns.connect(sid, esocket.clone(), auth, self.config.clone(), &self.state)?; // cancel the connect timeout task for v5 if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() { @@ -73,7 +78,7 @@ impl Client { /// Propagate a packet to a its target namespace fn sock_propagate_packet(&self, packet: Packet<'_>, sid: Sid) -> Result<(), Error> { if let Some(ns) = self.get_ns(&packet.ns) { - ns.recv(sid, packet.inner) + ns.recv(sid, packet.inner, &self.state) } else { #[cfg(feature = "tracing")] tracing::debug!("invalid namespace requested: {}", packet.ns); diff --git a/socketioxide/src/handler/connect.rs b/socketioxide/src/handler/connect.rs index 7349abfc..2771bdcb 100644 --- a/socketioxide/src/handler/connect.rs +++ b/socketioxide/src/handler/connect.rs @@ -56,6 +56,7 @@ use std::sync::Arc; use futures::Future; +use crate::extract::StateCell; use crate::{adapter::Adapter, socket::Socket}; use super::MakeErasedHandler; @@ -63,7 +64,7 @@ use super::MakeErasedHandler; /// A Type Erased [`ConnectHandler`] so it can be stored in a HashMap pub(crate) type BoxedConnectHandler = Box>; pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, auth: Option); + fn call(&self, s: Arc>, auth: Option, state: &StateCell); } impl MakeErasedHandler @@ -82,8 +83,8 @@ where T: Send + Sync + 'static, { #[inline(always)] - fn call(&self, s: Arc>, auth: Option) { - self.handler.call(s, auth); + fn call(&self, s: Arc>, auth: Option, state: &StateCell) { + self.handler.call(s, auth, state); } } @@ -99,7 +100,11 @@ pub trait FromConnectParts: Sized { /// Extract the arguments from the connect event. /// If it fails, the handler is not called - fn from_connect_parts(s: &Arc>, auth: &Option) -> Result; + fn from_connect_parts( + s: &Arc>, + auth: &Option, + state: &StateCell, + ) -> Result; } /// Define a handler for the connect event. @@ -109,7 +114,7 @@ pub trait FromConnectParts: Sized { /// * See the [`extract`](super::extract) module doc for more details on available extractors. pub trait ConnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, auth: Option); + fn call(&self, s: Arc>, auth: Option, state: &StateCell); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -136,15 +141,20 @@ macro_rules! impl_handler_async { A: Adapter, $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call( + &self, + s: Arc>, + auth: Option, + state: &StateCell) + { $( - let $ty = match $ty::from_connect_parts(&s, &auth) { - Ok(v) => v, - Err(_e) => { - #[cfg(feature = "tracing")] - tracing::error!("Error while extracting data: {}", _e); - return; - }, + let $ty = match $ty::from_connect_parts(&s, &auth, state) { + Ok(v) => v, + Err(_e) => { + #[cfg(feature = "tracing")] + tracing::error!("Error while extracting data: {}", _e); + return; + }, }; )* @@ -163,13 +173,18 @@ macro_rules! impl_handler { #[allow(non_snake_case, unused)] impl ConnectHandler for F where - F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, + F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, A: Adapter, $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call( + &self, + s: Arc>, + auth: Option, + state: &StateCell) + { $( - let $ty = match $ty::from_connect_parts(&s, &auth) { + let $ty = match $ty::from_connect_parts(&s, &auth, state) { Ok(v) => v, Err(_e) => { #[cfg(feature = "tracing")] diff --git a/socketioxide/src/handler/extract.rs b/socketioxide/src/handler/extract.rs index a4e720a1..697ef03c 100644 --- a/socketioxide/src/handler/extract.rs +++ b/socketioxide/src/handler/extract.rs @@ -116,7 +116,11 @@ where A: Adapter, { type Error = serde_json::Error; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts( + _: &Arc>, + auth: &Option, + _: &StateCell, + ) -> Result { auth.as_ref() .map(|a| serde_json::from_str::(a)) .unwrap_or(serde_json::from_str::("{}")) @@ -134,6 +138,7 @@ where v: &mut serde_json::Value, _: &mut Vec>, _: &Option, + _: &StateCell, ) -> Result { upwrap_array(v); serde_json::from_value(v.clone()).map(Data) @@ -149,7 +154,11 @@ where A: Adapter, { type Error = Infallible; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts( + _: &Arc>, + auth: &Option, + _: &StateCell, + ) -> Result { let v: Result = auth .as_ref() .map(|a| serde_json::from_str(a)) @@ -168,6 +177,7 @@ where v: &mut serde_json::Value, _: &mut Vec>, _: &Option, + _: &StateCell, ) -> Result { upwrap_array(v); Ok(TryData(serde_json::from_value(v.clone()))) @@ -178,7 +188,11 @@ pub struct SocketRef(Arc>); impl FromConnectParts for SocketRef { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts( + s: &Arc>, + _: &Option, + _: &StateCell, + ) -> Result { Ok(SocketRef(s.clone())) } } @@ -189,6 +203,7 @@ impl FromMessageParts for SocketRef { _: &mut serde_json::Value, _: &mut Vec>, _: &Option, + _: &StateCell, ) -> Result { Ok(SocketRef(s.clone())) } @@ -233,6 +248,7 @@ impl FromMessage for Bin { _: serde_json::Value, bin: Vec>, _: Option, + _: &StateCell, ) -> Result { Ok(Bin(bin)) } @@ -253,6 +269,7 @@ impl FromMessageParts for AckSender { _: &mut serde_json::Value, _: &mut Vec>, ack_id: &Option, + _: &StateCell, ) -> Result { Ok(Self::new(s.clone(), *ack_id)) } @@ -291,7 +308,11 @@ impl AckSender { impl FromConnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts( + s: &Arc>, + _: &Option, + _: &StateCell, + ) -> Result { Ok(s.protocol()) } } @@ -302,6 +323,7 @@ impl FromMessageParts for crate::ProtocolVersion { _: &mut serde_json::Value, _: &mut Vec>, _: &Option, + _: &StateCell, ) -> Result { Ok(s.protocol()) } @@ -315,7 +337,11 @@ impl FromDisconnectParts for crate::ProtocolVersion { impl FromConnectParts for crate::TransportType { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts( + s: &Arc>, + _: &Option, + _: &StateCell, + ) -> Result { Ok(s.transport_type()) } } @@ -326,10 +352,12 @@ impl FromMessageParts for crate::TransportType { _: &mut serde_json::Value, _: &mut Vec>, _: &Option, + _: &StateCell, ) -> Result { Ok(s.transport_type()) } } + impl FromDisconnectParts for crate::TransportType { type Error = Infallible; fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { @@ -346,3 +374,136 @@ impl FromDisconnectParts for DisconnectReason { Ok(reason) } } + +/// An Extractor that contains a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +/// It implements [`std::ops::Deref`] to access the inner type so you can use it as a normal reference. +/// +/// The specified state type must be the same as the one set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +/// If it is not the case, the handler won't be called and an error log will be print if the `tracing` feature is enabled. +/// +/// The state is shared between the entire socket.io app context. +/// +/// ### Example +/// ``` +/// # use socketioxide::{SocketIo, extract::{SocketRef, State}}; +/// # use serde::{Serialize, Deserialize}; +/// # use std::sync::atomic::AtomicUsize; +/// # use std::sync::atomic::Ordering; +/// #[derive(Default)] +/// struct MyAppData { +/// user_cnt: AtomicUsize, +/// } +/// impl MyAppData { +/// fn add_user(&self) { +/// self.user_cnt.fetch_add(1, Ordering::SeqCst); +/// } +/// fn user_cnt(&self) -> usize { +/// self.user_cnt.load(Ordering::SeqCst) +/// } +/// fn rm_user(&self) { +/// self.user_cnt.fetch_sub(1, Ordering::SeqCst); +/// } +/// } +/// let (_, io) = SocketIo::builder().with_state(MyAppData::default()).build_svc(); +/// io.ns("/", |socket: SocketRef, state: State| { +/// state.add_user(); +/// println!("User count: {:?}", state.user_cnt); +/// println!("User count: {}", state.user_cnt()); +/// }); +pub struct State { + state: StateCell, + _marker: std::marker::PhantomData, +} +/// The state must be in an `Arc` because it is impossible to have lifetime on the extracted data. +pub type StateCell = Arc; + +impl std::ops::Deref for State { + type Target = T; + fn deref(&self) -> &Self::Target { + // SAFETY: The state type is checked when the extractor is created + // TODO: use `downcast_ref_unchecked` when it is stable + unsafe { &*(&self.state as *const dyn std::any::Any as *const T) } + } +} +impl std::fmt::Debug for State { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("State") + .field("state", &self.state) + .field("_marker", &self._marker) + .finish() + } +} + +/// It was impossible to find the given state and therefore the handler won't be called. +#[derive(Debug, thiserror::Error)] +#[error("State not found")] +pub struct StateNotFound; + +impl FromConnectParts for State { + type Error = StateNotFound; + fn from_connect_parts( + _: &Arc>, + _: &Option, + state: &StateCell, + ) -> Result { + // SAFETY: This check is mandatory because later when the extractor is used, + // the state type is dereferenced without any checks + if !state.is::() { + return Err(StateNotFound); + } + Ok(State { + state: state.clone(), + _marker: std::marker::PhantomData, + }) + } +} +impl FromMessageParts for State { + type Error = StateNotFound; + fn from_message_parts( + _: &Arc>, + _: &mut Value, + _: &mut Vec>, + _: &Option, + state: &StateCell, + ) -> Result { + if !state.is::() { + return Err(StateNotFound); + } + Ok(State { + state: state.clone(), + _marker: std::marker::PhantomData, + }) + } +} + +#[cfg(test)] +mod tests { + use engineioxide::sid::Sid; + + use crate::ns::Namespace; + + use super::*; + + fn new_socket() -> Arc> { + let sid = Sid::new(); + Arc::new(Socket::new_dummy( + sid, + Namespace::::new_dummy([sid]), + )) + } + + #[test] + fn extract_state_not_found() { + struct A; + struct B; + let b = Arc::new(B) as StateCell; + State::::from_connect_parts(&new_socket(), &None, &b).unwrap_err(); + } + + #[test] + fn extract_state_found() { + struct A; + let a = Arc::new(A) as StateCell; + State::::from_connect_parts(&new_socket(), &None, &a).unwrap(); + } +} diff --git a/socketioxide/src/handler/message.rs b/socketioxide/src/handler/message.rs index a4501e14..be8609b4 100644 --- a/socketioxide/src/handler/message.rs +++ b/socketioxide/src/handler/message.rs @@ -77,6 +77,7 @@ use futures::Future; use serde_json::Value; use crate::adapter::Adapter; +use crate::extract::StateCell; use crate::socket::Socket; use super::MakeErasedHandler; @@ -85,7 +86,14 @@ use super::MakeErasedHandler; pub(crate) type BoxedMessageHandler = Box>; pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, v: Value, p: Vec>, ack_id: Option); + fn call( + &self, + s: Arc>, + v: Value, + p: Vec>, + ack_id: Option, + state: &StateCell, + ); } /// Define a handler for the connect event. @@ -101,7 +109,14 @@ pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { )] pub trait MessageHandler: Send + Sync + 'static { /// Call the handler with the given arguments - fn call(&self, s: Arc>, v: Value, p: Vec>, ack_id: Option); + fn call( + &self, + s: Arc>, + v: Value, + p: Vec>, + ack_id: Option, + state: &StateCell, + ); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -127,8 +142,15 @@ where A: Adapter, { #[inline(always)] - fn call(&self, s: Arc>, v: Value, p: Vec>, ack_id: Option) { - self.handler.call(s, v, p, ack_id); + fn call( + &self, + s: Arc>, + v: Value, + p: Vec>, + ack_id: Option, + state: &StateCell, + ) { + self.handler.call(s, v, p, ack_id, state); } } @@ -167,6 +189,7 @@ pub trait FromMessageParts: Sized { v: &mut Value, p: &mut Vec>, ack_id: &Option, + state: &StateCell, ) -> Result; } @@ -192,6 +215,7 @@ pub trait FromMessage: Sized { v: Value, p: Vec>, ack_id: Option, + state: &StateCell, ) -> Result; } @@ -207,8 +231,9 @@ where mut v: Value, mut p: Vec>, ack_id: Option, + state: &StateCell, ) -> Result { - Self::from_message_parts(&s, &mut v, &mut p, &ack_id) + Self::from_message_parts(&s, &mut v, &mut p, &ack_id, state) } } @@ -219,7 +244,7 @@ where Fut: Future + Send + 'static, A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec>, _: Option) { + fn call(&self, _: Arc>, _: Value, _: Vec>, _: Option, _: &StateCell) { let fut = (self.clone())(); tokio::spawn(fut); } @@ -231,7 +256,7 @@ where F: FnOnce() + Send + Sync + Clone + 'static, A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec>, _: Option) { + fn call(&self, _: Arc>, _: Value, _: Vec>, _: Option, _: &StateCell) { (self.clone())(); } } @@ -249,9 +274,16 @@ macro_rules! impl_async_handler { $( $ty: FromMessageParts + Send, )* $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec>, ack_id: Option) { + fn call( + &self, + s: Arc>, + mut v: Value, + mut p: Vec>, + ack_id: Option, + state: &StateCell) + { $( - let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { + let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id, state) { Ok(v) => v, Err(_e) => { #[cfg(feature = "tracing")] @@ -260,7 +292,7 @@ macro_rules! impl_async_handler { }, }; )* - let last = match $last::from_message(s, v, p, ack_id) { + let last = match $last::from_message(s, v, p, ack_id, state) { Ok(v) => v, Err(_e) => { #[cfg(feature = "tracing")] @@ -287,14 +319,21 @@ macro_rules! impl_handler { $( $ty: FromMessageParts + Send, )* $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec>, ack_id: Option) { + fn call( + &self, + s: Arc>, + mut v: Value, + mut p: Vec>, + ack_id: Option, + state: &StateCell) + { $( - let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { + let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id, state) { Ok(v) => v, Err(_) => return, }; )* - let last = match $last::from_message(s, v, p, ack_id) { + let last = match $last::from_message(s, v, p, ack_id, state) { Ok(v) => v, Err(_) => return, }; diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index faeec46b..39ca3475 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -11,7 +11,7 @@ use serde::de::DeserializeOwned; use crate::{ adapter::{Adapter, LocalAdapter}, client::Client, - extract::SocketRef, + extract::{SocketRef, StateCell}, handler::ConnectHandler, layer::SocketIoLayer, operators::{Operators, RoomParam}, @@ -57,6 +57,7 @@ pub struct SocketIoBuilder { config: SocketIoConfig, engine_config_builder: EngineIoConfigBuilder, adapter: std::marker::PhantomData, + state: StateCell, } impl SocketIoBuilder { @@ -66,6 +67,7 @@ impl SocketIoBuilder { config: SocketIoConfig::default(), engine_config_builder: EngineIoConfigBuilder::new().req_path("/socket.io".to_string()), adapter: std::marker::PhantomData, + state: Arc::new(()), } } @@ -156,19 +158,30 @@ impl SocketIoBuilder { /// Sets a custom [`Adapter`] for this [`SocketIoBuilder`] pub fn with_adapter(self) -> SocketIoBuilder { SocketIoBuilder { + adapter: std::marker::PhantomData, config: self.config, engine_config_builder: self.engine_config_builder, - adapter: std::marker::PhantomData, + state: self.state, } } + /// Sets a custom global state for the [`SocketIo`] instance. + /// This state will be accessible from every handler with the [`State`](crate::extract::State) extractor. + /// Only one state can be set for a [`SocketIo`] instance. Setting a new state will replace the previous one. + #[inline] + pub fn with_state(mut self, state: S) -> Self { + self.state = Arc::new(state); + self + } + /// Builds a [`SocketIoLayer`] and a [`SocketIo`] instance /// /// The layer can be used as a tower layer pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); - let (layer, client) = SocketIoLayer::from_config(Arc::new(self.config)); + let (layer, client) = SocketIoLayer::from_config(Arc::new(self.config), self.state); + (layer, SocketIo(client)) } @@ -180,7 +193,8 @@ impl SocketIoBuilder { self.config.engine_config = self.engine_config_builder.build(); let (svc, client) = - SocketIoService::with_config_inner(NotFoundService, Arc::new(self.config)); + SocketIoService::with_config_inner(NotFoundService, Arc::new(self.config), self.state); + (svc, SocketIo(client)) } @@ -190,7 +204,9 @@ impl SocketIoBuilder { pub fn build_with_inner_svc(mut self, svc: S) -> (SocketIoService, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); - let (svc, client) = SocketIoService::with_config_inner(svc, Arc::new(self.config)); + let (svc, client) = + SocketIoService::with_config_inner(svc, Arc::new(self.config), self.state); + (svc, SocketIo(client)) } } diff --git a/socketioxide/src/layer.rs b/socketioxide/src/layer.rs index ca7b47b6..dd370409 100644 --- a/socketioxide/src/layer.rs +++ b/socketioxide/src/layer.rs @@ -40,6 +40,7 @@ use tower::Layer; use crate::{ adapter::{Adapter, LocalAdapter}, client::Client, + extract::StateCell, service::SocketIoService, SocketIoConfig, }; @@ -58,8 +59,11 @@ impl Clone for SocketIoLayer { } impl SocketIoLayer { - pub(crate) fn from_config(config: Arc) -> (Self, Arc>) { - let client = Arc::new(Client::new(config.clone())); + pub(crate) fn from_config( + config: Arc, + state: StateCell, + ) -> (Self, Arc>) { + let client = Arc::new(Client::new(config.clone(), state)); let layer = Self { client: client.clone(), }; diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index beed0c8d..b7a5f7cd 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -7,6 +7,7 @@ use std::{ use crate::{ adapter::Adapter, errors::Error, + extract::StateCell, handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler}, packet::{Packet, PacketData}, socket::Socket, @@ -43,6 +44,7 @@ impl Namespace { esocket: Arc>, auth: Option, config: Arc, + state: &StateCell, ) -> Result<(), serde_json::Error> { let socket: Arc> = Socket::new(sid, self.clone(), esocket.clone(), config).into(); @@ -56,7 +58,7 @@ impl Namespace { return Ok(()); } - self.handler.call(socket, auth); + self.handler.call(socket, auth, state); Ok(()) } @@ -72,11 +74,11 @@ impl Namespace { self.sockets.read().unwrap().values().any(|s| s.id == sid) } - pub fn recv(&self, sid: Sid, packet: PacketData<'_>) -> Result<(), Error> { + pub fn recv(&self, sid: Sid, packet: PacketData<'_>, state: &StateCell) -> Result<(), Error> { match packet { PacketData::Connect(_) => unreachable!("connect packets should be handled before"), PacketData::ConnectError => Err(Error::InvalidPacketType), - packet => self.get_socket(sid)?.recv(packet), + packet => self.get_socket(sid)?.recv(packet, state), } } pub fn get_socket(&self, sid: Sid) -> Result>, Error> { diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index 6ee6b96a..e2973f5c 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -37,6 +37,7 @@ use tower::Service; use crate::{ adapter::{Adapter, LocalAdapter}, client::Client, + extract::StateCell, SocketIoConfig, }; @@ -77,9 +78,10 @@ impl SocketIoService { pub(crate) fn with_config_inner( inner: S, config: Arc, + state: StateCell, ) -> (Self, Arc>) { let engine_config = config.engine_config.clone(); - let client = Arc::new(Client::new(config)); + let client = Arc::new(Client::new(config, state)); let svc = EngineIoService::with_config_inner(inner, client.clone(), engine_config); (Self { engine_svc: svc }, client) } diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index a11e198d..05d0557e 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -23,6 +23,7 @@ use crate::extensions::Extensions; use crate::{ adapter::{Adapter, LocalAdapter, Room}, errors::{AckError, Error}, + extract::StateCell, handler::{ BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler, MessageHandler, @@ -593,11 +594,15 @@ impl Socket { } // Receives data from client: - pub(crate) fn recv(self: Arc, packet: PacketData<'_>) -> Result<(), Error> { + pub(crate) fn recv( + self: Arc, + packet: PacketData<'_>, + state: &StateCell, + ) -> Result<(), Error> { match packet { - PacketData::Event(e, data, ack) => self.recv_event(&e, data, ack), + PacketData::Event(e, data, ack) => self.recv_event(&e, data, ack, state), PacketData::EventAck(data, ack_id) => self.recv_ack(data, ack_id), - PacketData::BinaryEvent(e, packet, ack) => self.recv_bin_event(&e, packet, ack), + PacketData::BinaryEvent(e, packet, ack) => self.recv_bin_event(&e, packet, ack, state), PacketData::BinaryAck(packet, ack) => self.recv_bin_ack(packet, ack), PacketData::Disconnect => self .close(DisconnectReason::ClientNSDisconnect) @@ -645,9 +650,15 @@ impl Socket { self.esocket.protocol.into() } - fn recv_event(self: Arc, e: &str, data: Value, ack: Option) -> Result<(), Error> { + fn recv_event( + self: Arc, + e: &str, + data: Value, + ack: Option, + state: &StateCell, + ) -> Result<(), Error> { if let Some(handler) = self.message_handlers.read().unwrap().get(e) { - handler.call(self.clone(), data, vec![], ack); + handler.call(self.clone(), data, vec![], ack, state); } Ok(()) } @@ -657,9 +668,10 @@ impl Socket { e: &str, packet: BinaryPacket, ack: Option, + state: &StateCell, ) -> Result<(), Error> { if let Some(handler) = self.message_handlers.read().unwrap().get(e) { - handler.call(self.clone(), packet.data, packet.bin, ack); + handler.call(self.clone(), packet.data, packet.bin, ack, state); } Ok(()) }