From 1ede2eacc3840244de2809ecf1de5293b65d6851 Mon Sep 17 00:00:00 2001 From: Thibaut Lorrain Date: Mon, 6 Aug 2018 17:09:41 +0200 Subject: [PATCH] add unsubscribe --- .gitignore | 1 + src/client.rs | 34 ++++++++++------- src/connection.rs | 6 +++ src/lib.rs | 2 +- src/state.rs | 91 +++++++++++++++++++++++++++++++++------------- tests/testsuite.rs | 73 +++++++++++++++++++++++++++++++++++++ 6 files changed, 166 insertions(+), 41 deletions(-) diff --git a/.gitignore b/.gitignore index 5e6c472..9a4f955 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ *.exe /.vscode/ +/.idea/ # Generated by Cargo /target/ diff --git a/src/client.rs b/src/client.rs index f301789..4122558 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,6 +4,8 @@ use mqtt3::{QoS, ToTopicPath, TopicPath}; use mio_more::channel::*; +use std::sync::atomic::{AtomicUsize, Ordering}; + use MqttOptions; #[allow(unused)] @@ -11,6 +13,7 @@ use MqttOptions; pub enum Command { Status(#[debug_stub = ""] ::std::sync::mpsc::Sender<::state::MqttConnectionStatus>), Subscribe(Subscription), + Unsubscribe(SubscriptionToken), Publish(Publish), Connect, Disconnect, @@ -18,6 +21,7 @@ pub enum Command { pub struct MqttClient { nw_request_tx: SyncSender, + subscription_id_source: AtomicUsize } impl MqttClient { @@ -73,6 +77,7 @@ impl MqttClient { Ok(MqttClient { nw_request_tx: commands_tx, + subscription_id_source: AtomicUsize::new(0), }) } @@ -84,7 +89,7 @@ impl MqttClient { Ok(SubscriptionBuilder { client: self, it: Subscription { - id: None, + id: self.subscription_id_source.fetch_add(1, Ordering::Relaxed), topic_path: topic_path.to_topic_path()?, qos: ::mqtt3::QoS::AtMostOnce, callback, @@ -92,6 +97,10 @@ impl MqttClient { }) } + pub fn unsubscribe(&self, token: SubscriptionToken) -> Result<()> { + self.send_command(Command::Unsubscribe(token)) + } + pub fn publish(&self, topic_path: T) -> Result { Ok(PublishBuilder { client: self, @@ -126,7 +135,7 @@ pub type SubscriptionCallback = Box; #[derive(DebugStub)] pub struct Subscription { - pub id: Option, + pub id: usize, pub topic_path: TopicPath, pub qos: ::mqtt3::QoS, #[debug_stub = ""] pub callback: SubscriptionCallback, @@ -139,16 +148,6 @@ pub struct SubscriptionBuilder<'a> { } impl<'a> SubscriptionBuilder<'a> { - pub fn id(self, s: S) -> SubscriptionBuilder<'a> { - let SubscriptionBuilder { client, it } = self; - SubscriptionBuilder { - client, - it: Subscription { - id: Some(s.to_string()), - ..it - }, - } - } pub fn qos(self, qos: QoS) -> SubscriptionBuilder<'a> { let SubscriptionBuilder { client, it } = self; SubscriptionBuilder { @@ -156,11 +155,18 @@ impl<'a> SubscriptionBuilder<'a> { it: Subscription { qos, ..it }, } } - pub fn send(self) -> Result<()> { - self.client.send_command(Command::Subscribe(self.it)) + pub fn send(self) -> Result { + let token = SubscriptionToken { id: self.it.id}; + self.client.send_command(Command::Subscribe(self.it))?; + Ok(token) } } +#[derive(Debug)] +pub struct SubscriptionToken { + pub id: usize +} + #[derive(Debug)] pub struct Publish { pub topic: TopicPath, diff --git a/src/connection.rs b/src/connection.rs index e0210c1..af9e878 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -389,6 +389,7 @@ impl ConnectionState { self.turn_command()?; } mqtt3::Packet::Suback(suback) => self.mqtt_state.handle_incoming_suback(suback)?, + mqtt3::Packet::Unsuback(packet_identifier) => self.mqtt_state.handle_incoming_unsuback(packet_identifier)?, mqtt3::Packet::Publish(publish) => { let (_, server) = self.mqtt_state.handle_incoming_publish(publish)?; if let Some(server) = server { @@ -437,6 +438,11 @@ impl ConnectionState { let packet = self.mqtt_state.handle_outgoing_subscribe(vec![sub])?; self.send_packet(mqtt3::Packet::Subscribe(packet))? } + Command::Unsubscribe(token) => { + if let Some(packet) = self.mqtt_state.handle_outgoing_unsubscribe(vec![token.id])? { + self.send_packet(mqtt3::Packet::Unsubscribe(packet))? + } + } Command::Status(tx) => { let _ = tx.send(self.state().status()); } diff --git a/src/lib.rs b/src/lib.rs index 13f9e5f..01d70dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,5 +43,5 @@ mod state; pub use rustls::ClientConfig as RustlsConfig; pub use options::{MqttOptions, ReconnectOptions, TlsOptions}; -pub use client::MqttClient; +pub use client::{MqttClient, SubscriptionToken}; pub use mqtt3::{Message, Publish, QoS, ToTopicPath, TopicPath}; diff --git a/src/state.rs b/src/state.rs index 7cd5d63..53261bb 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,9 +1,8 @@ -use std::time::{Duration, Instant}; -use std::collections::VecDeque; - +use error::*; use mqtt3; use MqttOptions; -use error::*; +use std::collections::{HashMap, VecDeque}; +use std::time::{Duration, Instant}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MqttConnectionStatus { @@ -41,7 +40,8 @@ pub struct MqttState { // Even so, if broker crashes, all its state will be lost (most brokers). // client should resubscribe it comes back up again or else the data will // be lost - subscriptions: Vec<::client::Subscription>, + subscriptions: HashMap, + path_usage: HashMap, } /// Design: `MqttState` methods will just modify the state of the object @@ -60,7 +60,8 @@ impl MqttState { last_flush: Instant::now(), last_pkid: mqtt3::PacketIdentifier(0), outgoing_pub: VecDeque::new(), - subscriptions: Vec::new(), + subscriptions: HashMap::new(), + path_usage: HashMap::new(), } } @@ -97,13 +98,13 @@ impl MqttState { use self::MqttConnectionStatus::*; use ReconnectOptions::*; match (self.connection_status, self.opts.reconnect) { - (Handshake { initial: true }, Always(d)) - | (Handshake {..}, AfterFirstSuccess(d)) + (Handshake { initial: true }, Always(d)) + | (Handshake { .. }, AfterFirstSuccess(d)) | (Connected, AfterFirstSuccess(d)) | (Connected, Always(d)) | (WantConnect { .. }, AfterFirstSuccess(d)) | (WantConnect { .. }, Always(d)) - => self.connection_status = WantConnect { when: Instant::now()+d }, + => self.connection_status = WantConnect { when: Instant::now() + d }, _ => self.connection_status = Disconnected } } @@ -131,13 +132,14 @@ impl MqttState { } else { let sub = if self.subscriptions.len() > 0 { Some(mqtt3::Subscribe { - pid: self.next_pkid(), - topics: self.subscriptions.iter().map(|s| { + pid: self.next_pkid(), + topics: self.subscriptions.iter().map(|(_id, s)| { ::mqtt3::SubscribeTopic { topic_path: s.topic_path.path.clone(), qos: s.qos, - }}).collect() - }) + } + }).collect(), + }) } else { None }; @@ -202,7 +204,7 @@ impl MqttState { let qos = publish.qos; let concrete = ::mqtt3::TopicPath::from_str(&publish.topic_name)?; - for sub in &self.subscriptions { + for (_id, sub) in &self.subscriptions { if sub.topic_path.is_match(&concrete) { (sub.callback)(&publish); } @@ -289,7 +291,12 @@ impl MqttState { } }) .collect(); - self.subscriptions.extend(subs); + for s in &subs { + *self.path_usage.entry(s.topic_path.path.clone()).or_insert(0) += 1; + } + self.subscriptions.extend(subs.into_iter().map(|it| { + (it.id, it) + })); if self.connection_status == MqttConnectionStatus::Connected { Ok(mqtt3::Subscribe { pid: pkid, topics }) @@ -302,14 +309,47 @@ impl MqttState { } } + pub fn handle_outgoing_unsubscribe( + &mut self, + ids: Vec, + ) -> Result> { + let mut topics = vec![]; + for id in ids { + if let Some(sub) = self.subscriptions.remove(&id) { + // we unwrap here because if the value is not there, there is an error in this code + let mut path_count = self.path_usage.get_mut(&sub.topic_path.path).unwrap(); + *path_count -= 1; + if *path_count == 0 { topics.push(sub.topic_path.path) } + } + } + if !topics.is_empty() { + let pkid = self.next_pkid(); + + if self.connection_status == MqttConnectionStatus::Connected { + Ok(Some(mqtt3::Unsubscribe { pid: pkid, topics })) + } else { + error!( + "State = {:?}. Shouldn't unsubscribe in this state", + self.connection_status + ); + Err(ErrorKind::InvalidState.into()) + } + } else { + Ok(None) + } + } pub fn handle_incoming_suback(&mut self, ack: mqtt3::Suback) -> Result<()> { if ack.return_codes .iter() .any(|v| *v == ::mqtt3::SubscribeReturnCodes::Failure) - { - Err(format!("rejected subscription"))? - }; + { + Err(format!("rejected subscription"))? + }; + Ok(()) + } + + pub fn handle_incoming_unsuback(&mut self, ack: mqtt3::PacketIdentifier) -> Result<()> { Ok(()) } @@ -340,14 +380,13 @@ impl MqttState { #[cfg(test)] mod test { + use error::*; + use mqtt3::*; + use options::MqttOptions; use std::sync::Arc; use std::thread; use std::time::Duration; - use super::{MqttConnectionStatus, MqttState}; - use mqtt3::*; - use options::MqttOptions; - use error::*; #[test] fn next_pkid_roll() { @@ -546,7 +585,7 @@ mod test { mqtt.handle_socket_disconnect(); assert_eq!(mqtt.outgoing_pub.len(), 0); match mqtt.connection_status { - MqttConnectionStatus::WantConnect { .. } => {}, + MqttConnectionStatus::WantConnect { .. } => {} _ => panic!() } assert_eq!(mqtt.await_pingresp, false); @@ -574,7 +613,7 @@ mod test { mqtt.handle_socket_disconnect(); assert_eq!(mqtt.outgoing_pub.len(), 3); match mqtt.connection_status { - MqttConnectionStatus::WantConnect { .. } => {}, + MqttConnectionStatus::WantConnect { .. } => {} _ => panic!() } assert_eq!(mqtt.await_pingresp, false); @@ -585,7 +624,7 @@ mod test { let mut mqtt = MqttState::new(MqttOptions::new("test-id", "127.0.0.1:1883")); match mqtt.connection_status { - MqttConnectionStatus::WantConnect { .. } => {}, + MqttConnectionStatus::WantConnect { .. } => {} _ => panic!() } mqtt.handle_outgoing_connect(true); @@ -609,7 +648,7 @@ mod test { assert!(mqtt.handle_incoming_connack(connack).is_err()); match mqtt.connection_status { - MqttConnectionStatus::WantConnect { .. } => {}, + MqttConnectionStatus::WantConnect { .. } => {} _ => panic!() } } diff --git a/tests/testsuite.rs b/tests/testsuite.rs index b796ec0..a757d86 100644 --- a/tests/testsuite.rs +++ b/tests/testsuite.rs @@ -117,6 +117,79 @@ fn basic_publishes_and_subscribes() { assert_eq!(3, final_count.load(Ordering::SeqCst)); } +#[test] +fn publishes_and_subscribes_and_unsubscribes() { + // loggerv::init_with_level(log::LogLevel::Debug); + let client_options = MqttOptions::new("pubsubunsub", MOSQUITTO_ADDR); + let count = Arc::new(AtomicUsize::new(0)); + let final_count = count.clone(); + let count = count.clone(); + + let count2 = Arc::new(AtomicUsize::new(0)); + let final_count2 = count2.clone(); + let count2 = count2.clone(); + + let request = MqttClient::start(client_options).expect("Coudn't start"); + let token = request + .subscribe( + "test/pubsubunsub", + Box::new(move |_| { + count.fetch_add(1, Ordering::SeqCst); + }), + ) + .unwrap() + .send() + .unwrap(); + + let token2 = request + .subscribe( + "test/pubsubunsub", + Box::new(move |_| { + count2.fetch_add(1, Ordering::SeqCst); + }), + ) + .unwrap() + .send() + .unwrap(); + + let payload = format!("hello rust"); + request + .publish("test/pubsubunsub") + .unwrap() + .payload(payload.clone().into_bytes()) + .send() + .unwrap(); + + thread::sleep(Duration::from_secs(1)); + request.unsubscribe(token).unwrap(); + thread::sleep(Duration::from_secs(1)); + + request + .publish("test/pubsubunsub") + .unwrap() + .payload(payload.clone().into_bytes()) + .send() + .unwrap(); + + thread::sleep(Duration::from_secs(1)); + + request.unsubscribe(token2).unwrap(); + thread::sleep(Duration::from_secs(1)); + + request + .publish("test/pubsubunsub") + .unwrap() + .payload(payload.clone().into_bytes()) + .send() + .unwrap(); + + thread::sleep(Duration::from_secs(1)); + + assert_eq!(1, final_count.load(Ordering::SeqCst)); + assert_eq!(2, final_count2.load(Ordering::SeqCst)); +} + + #[test] fn alive() { // loggerv::init_with_level(log::LogLevel::Debug);