Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add global state management #187

Closed
wants to merge 11 commits into from
17 changes: 9 additions & 8 deletions examples/angular-todomvc/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@ use axum::Server;

use serde::{Deserialize, Serialize};
use socketioxide::{
extract::{Data, SocketRef},
extract::{Data, SocketRef, State},
SocketIo,
};
use tower::ServiceBuilder;
use tower_http::{cors::CorsLayer, services::ServeDir};
use tracing::{error, info};
use tracing_subscriber::FmtSubscriber;

static TODOS: Mutex<Vec<Todo>> = Mutex::new(vec![]);

#[derive(Debug, Clone, Serialize, Deserialize)]
struct Todo {
completed: bool,
editing: bool,
title: String,
}

type Todos = Mutex<Vec<Todo>>;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let subscriber = FmtSubscriber::new();
Expand All @@ -29,23 +29,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

info!("Starting server");

let (layer, io) = SocketIo::new_layer();
let todos: Todos = Mutex::new(vec![]);
let (layer, io) = SocketIo::builder().with_state(todos).build_layer();

io.ns("/", |s: SocketRef| {
io.ns("/", |s: SocketRef, todos: State<Todos>| {
info!("New connection: {}", s.id);

let todos = TODOS.lock().unwrap().clone();
let todos = todos.lock().unwrap().clone();

// Because variadic args are not supported, array arguments are flattened.
// Therefore to send a json array (required for the todomvc app) we need to wrap it in another array.
s.emit("todos", [todos]).unwrap();

s.on(
"update-store",
|s: SocketRef, Data::<Vec<Todo>>(new_todos)| {
|s: SocketRef, Data::<Vec<Todo>>(new_todos), todos: State<Todos>| {
info!("Received update-store event: {:?}", new_todos);

let mut todos = TODOS.lock().unwrap();
let mut todos = todos.lock().unwrap();
todos.clear();
todos.extend_from_slice(&new_todos);

Expand Down
60 changes: 36 additions & 24 deletions examples/basic-crud-application/src/handlers/todo.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::{
collections::HashMap,
sync::{OnceLock, RwLock},
};
use std::{collections::HashMap, sync::RwLock};

use serde::{Deserialize, Serialize};
use socketioxide::extract::{AckSender, Data, SocketRef};
use socketioxide::extract::{AckSender, Data, SocketRef, State};
use tracing::info;
use uuid::Uuid;

Expand All @@ -24,52 +21,67 @@ pub struct PartialTodo {
title: String,
}

static TODOS: OnceLock<RwLock<HashMap<Uuid, Todo>>> = OnceLock::new();
fn get_store() -> &'static RwLock<HashMap<Uuid, Todo>> {
TODOS.get_or_init(|| RwLock::new(HashMap::new()))
#[derive(Default)]
pub struct Todos(pub RwLock<HashMap<Uuid, Todo>>);
impl Todos {
pub fn insert(&self, id: Uuid, todo: Todo) {
self.0.write().unwrap().insert(id, todo);
}

pub fn get(&self, id: &Uuid) -> Option<Todo> {
self.0.read().unwrap().get(id).cloned()
}

pub fn remove(&self, id: &Uuid) -> Option<Todo> {
self.0.write().unwrap().remove(id)
}
pub fn values(&self) -> Vec<Todo> {
self.0.read().unwrap().values().cloned().collect()
}
}

pub fn create(s: SocketRef, Data(data): Data<PartialTodo>, ack: AckSender) {
pub fn create(s: SocketRef, Data(data): Data<PartialTodo>, ack: AckSender, todos: State<Todos>) {
let id = Uuid::new_v4();
let todo = Todo { id, inner: data };

get_store().write().unwrap().insert(id, todo.clone());
todos.insert(id, todo.clone());

let res: Response<_> = id.into();
ack.send(res).ok();

s.broadcast().emit("todo:created", todo).ok();
}

pub async fn read(Data(id): Data<Uuid>, ack: AckSender) {
let todos = get_store().read().unwrap();

pub async fn read(Data(id): Data<Uuid>, ack: AckSender, todos: State<Todos>) {
let todo = todos.get(&id).ok_or(Error::NotFound);
ack.send(todo).ok();
}

pub async fn update(s: SocketRef, Data(data): Data<Todo>, ack: AckSender) {
let mut todos = get_store().write().unwrap();
let res = todos.get_mut(&data.id).ok_or(Error::NotFound).map(|todo| {
todo.inner = data.inner.clone();
s.broadcast().emit("todo:updated", data).ok();
});
pub async fn update(s: SocketRef, Data(data): Data<Todo>, ack: AckSender, todos: State<Todos>) {
let res = todos
.0
.write()
.unwrap()
.get_mut(&data.id)
.ok_or(Error::NotFound)
.map(|todo| {
todo.inner = data.inner.clone();
s.broadcast().emit("todo:updated", data).ok();
});

ack.send(res).ok();
}

pub async fn delete(s: SocketRef, Data(id): Data<Uuid>, ack: AckSender) {
let mut todos = get_store().write().unwrap();
pub async fn delete(s: SocketRef, Data(id): Data<Uuid>, ack: AckSender, todos: State<Todos>) {
let res = todos.remove(&id).ok_or(Error::NotFound).map(|_| {
s.broadcast().emit("todo:deleted", id).ok();
});

ack.send(res).ok();
}

pub async fn list(ack: AckSender) {
let todos = get_store().read().unwrap();
let res: Response<_> = todos.values().cloned().collect::<Vec<_>>().into();
pub async fn list(ack: AckSender, todos: State<Todos>) {
let res: Response<_> = todos.values().into();
info!("Sending todos: {:?}", res);
ack.send(res).ok();
}
6 changes: 5 additions & 1 deletion examples/basic-crud-application/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use tower_http::{cors::CorsLayer, services::ServeDir};
use tracing::{error, info};
use tracing_subscriber::FmtSubscriber;

use crate::handlers::todo::Todos;

mod handlers;

#[tokio::main]
Expand All @@ -16,7 +18,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

info!("Starting server");

let (layer, io) = SocketIo::new_layer();
let (layer, io) = SocketIo::builder()
.with_state(Todos::default())
.build_layer();

io.ns("/", |s: SocketRef| {
s.on("todo:create", handlers::todo::create);
Expand Down
10 changes: 7 additions & 3 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@
pub struct Client<A: Adapter> {
pub(crate) config: Arc<SocketIoConfig>,
ns: RwLock<HashMap<Cow<'static, str>, Arc<Namespace<A>>>>,

state: StateCell,

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope
}

impl<A: Adapter> Client<A> {
pub fn new(config: Arc<SocketIoConfig>) -> Self {
pub fn new(config: Arc<SocketIoConfig>, state: StateCell) -> Self {

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope
Self {
config,
ns: RwLock::new(HashMap::new()),

state,
}
}

Expand All @@ -45,7 +49,7 @@

let sid = esocket.id;
if let Some(ns) = self.get_ns(ns_path) {
ns.connect(sid, esocket.clone(), auth, self.config.clone())?;
ns.connect(sid, esocket.clone(), auth, self.config.clone(), &self.state)?;

// cancel the connect timeout task for v5
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
Expand Down Expand Up @@ -73,7 +77,7 @@
/// Propagate a packet to a its target namespace
fn sock_propagate_packet(&self, packet: Packet<'_>, sid: Sid) -> Result<(), Error> {
if let Some(ns) = self.get_ns(&packet.ns) {
ns.recv(sid, packet.inner)
ns.recv(sid, packet.inner, &self.state)
} else {
#[cfg(feature = "tracing")]
tracing::debug!("invalid namespace requested: {}", packet.ns);
Expand Down
46 changes: 30 additions & 16 deletions socketioxide/src/handler/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
/// A Type Erased [`ConnectHandler`] so it can be stored in a HashMap
pub(crate) type BoxedConnectHandler<A> = Box<dyn ErasedConnectHandler<A>>;
pub(crate) trait ErasedConnectHandler<A: Adapter>: Send + Sync + 'static {
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>);
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>, state: &StateCell);

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope
}

impl<A: Adapter, T, H> MakeErasedHandler<H, A, T>
Expand All @@ -82,8 +82,8 @@
T: Send + Sync + 'static,
{
#[inline(always)]
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>) {
self.handler.call(s, auth);
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>, state: &StateCell) {

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope
self.handler.call(s, auth, state);
}
}

Expand All @@ -99,7 +99,11 @@

/// Extract the arguments from the connect event.
/// If it fails, the handler is not called
fn from_connect_parts(s: &Arc<Socket<A>>, auth: &Option<String>) -> Result<Self, Self::Error>;
fn from_connect_parts(
s: &Arc<Socket<A>>,
auth: &Option<String>,
state: &StateCell,

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope
) -> Result<Self, Self::Error>;
}

/// Define a handler for the connect event.
Expand All @@ -109,7 +113,7 @@
/// * See the [`extract`](super::extract) module doc for more details on available extractors.
pub trait ConnectHandler<A: Adapter, T>: Send + Sync + 'static {
/// Call the handler with the given arguments.
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>);
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>, state: &StateCell);

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

#[doc(hidden)]
fn phantom(&self) -> std::marker::PhantomData<T> {
Expand All @@ -136,15 +140,20 @@
A: Adapter,
$( $ty: FromConnectParts<A> + Send, )*
{
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>) {
fn call(
&self,
s: Arc<Socket<A>>,
auth: Option<String>,
state: &StateCell)

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope
{
$(
let $ty = match $ty::from_connect_parts(&s, &auth) {
Ok(v) => v,
Err(_e) => {
#[cfg(feature = "tracing")]
tracing::error!("Error while extracting data: {}", _e);
return;
},
let $ty = match $ty::from_connect_parts(&s, &auth, state) {
Ok(v) => v,
Err(_e) => {
#[cfg(feature = "tracing")]
tracing::error!("Error while extracting data: {}", _e);
return;
},
};
)*

Expand All @@ -163,13 +172,18 @@
#[allow(non_snake_case, unused)]
impl<A, F, $($ty,)*> ConnectHandler<A, (private::Sync, $($ty,)*)> for F
where
F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static,
F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static,
A: Adapter,
$( $ty: FromConnectParts<A> + Send, )*
{
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>) {
fn call(
&self,
s: Arc<Socket<A>>,
auth: Option<String>,
state: &StateCell)

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope

Check failure

Code scanning / clippy

cannot find type StateCell in this scope Error

cannot find type StateCell in this scope
{
$(
let $ty = match $ty::from_connect_parts(&s, &auth) {
let $ty = match $ty::from_connect_parts(&s, &auth, state) {
Ok(v) => v,
Err(_e) => {
#[cfg(feature = "tracing")]
Expand Down
Loading
Loading