Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unsubscribe #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*.exe

/.vscode/
/.idea/

# Generated by Cargo
/target/
Expand Down
34 changes: 20 additions & 14 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@ use mqtt3::{QoS, ToTopicPath, TopicPath};

use mio_more::channel::*;

use std::sync::atomic::{AtomicUsize, Ordering};

use MqttOptions;

#[allow(unused)]
#[derive(DebugStub)]
pub enum Command {
Status(#[debug_stub = ""] ::std::sync::mpsc::Sender<::state::MqttConnectionStatus>),
Subscribe(Subscription),
Unsubscribe(SubscriptionToken),
Publish(Publish),
Connect,
Disconnect,
}

pub struct MqttClient {
nw_request_tx: SyncSender<Command>,
subscription_id_source: AtomicUsize
}

impl MqttClient {
Expand Down Expand Up @@ -73,6 +77,7 @@ impl MqttClient {

Ok(MqttClient {
nw_request_tx: commands_tx,
subscription_id_source: AtomicUsize::new(0),
})
}

Expand All @@ -84,14 +89,18 @@ impl MqttClient {
Ok(SubscriptionBuilder {
client: self,
it: Subscription {
id: None,
id: self.subscription_id_source.fetch_add(1, Ordering::Relaxed),
topic_path: topic_path.to_topic_path()?,
qos: ::mqtt3::QoS::AtMostOnce,
callback,
},
})
}

pub fn unsubscribe(&self, token: SubscriptionToken) -> Result<()> {
self.send_command(Command::Unsubscribe(token))
}

pub fn publish<T: ToTopicPath>(&self, topic_path: T) -> Result<PublishBuilder> {
Ok(PublishBuilder {
client: self,
Expand Down Expand Up @@ -126,7 +135,7 @@ pub type SubscriptionCallback = Box<Fn(&::mqtt3::Publish) + Send>;

#[derive(DebugStub)]
pub struct Subscription {
pub id: Option<String>,
pub id: usize,
pub topic_path: TopicPath,
pub qos: ::mqtt3::QoS,
#[debug_stub = ""] pub callback: SubscriptionCallback,
Expand All @@ -139,28 +148,25 @@ pub struct SubscriptionBuilder<'a> {
}

impl<'a> SubscriptionBuilder<'a> {
pub fn id<S: ToString>(self, s: S) -> SubscriptionBuilder<'a> {
let SubscriptionBuilder { client, it } = self;
SubscriptionBuilder {
client,
it: Subscription {
id: Some(s.to_string()),
..it
},
}
}
pub fn qos(self, qos: QoS) -> SubscriptionBuilder<'a> {
let SubscriptionBuilder { client, it } = self;
SubscriptionBuilder {
client,
it: Subscription { qos, ..it },
}
}
pub fn send(self) -> Result<()> {
self.client.send_command(Command::Subscribe(self.it))
pub fn send(self) -> Result<SubscriptionToken> {
let token = SubscriptionToken { id: self.it.id};
self.client.send_command(Command::Subscribe(self.it))?;
Ok(token)
}
}

#[derive(Debug)]
pub struct SubscriptionToken {
pub id: usize
}

#[derive(Debug)]
pub struct Publish {
pub topic: TopicPath,
Expand Down
6 changes: 6 additions & 0 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ impl ConnectionState {
self.turn_command()?;
}
mqtt3::Packet::Suback(suback) => self.mqtt_state.handle_incoming_suback(suback)?,
mqtt3::Packet::Unsuback(packet_identifier) => self.mqtt_state.handle_incoming_unsuback(packet_identifier)?,
mqtt3::Packet::Publish(publish) => {
let (_, server) = self.mqtt_state.handle_incoming_publish(publish)?;
if let Some(server) = server {
Expand Down Expand Up @@ -437,6 +438,11 @@ impl ConnectionState {
let packet = self.mqtt_state.handle_outgoing_subscribe(vec![sub])?;
self.send_packet(mqtt3::Packet::Subscribe(packet))?
}
Command::Unsubscribe(token) => {
if let Some(packet) = self.mqtt_state.handle_outgoing_unsubscribe(vec![token.id])? {
self.send_packet(mqtt3::Packet::Unsubscribe(packet))?
}
}
Command::Status(tx) => {
let _ = tx.send(self.state().status());
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ mod state;

pub use rustls::ClientConfig as RustlsConfig;
pub use options::{MqttOptions, ReconnectOptions, TlsOptions};
pub use client::MqttClient;
pub use client::{MqttClient, SubscriptionToken};
pub use mqtt3::{Message, Publish, QoS, ToTopicPath, TopicPath};
91 changes: 65 additions & 26 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::time::{Duration, Instant};
use std::collections::VecDeque;

use error::*;
use mqtt3;
use MqttOptions;
use error::*;
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MqttConnectionStatus {
Expand Down Expand Up @@ -41,7 +40,8 @@ pub struct MqttState {
// Even so, if broker crashes, all its state will be lost (most brokers).
// client should resubscribe it comes back up again or else the data will
// be lost
subscriptions: Vec<::client::Subscription>,
subscriptions: HashMap<usize, ::client::Subscription>,
path_usage: HashMap<String, usize>,
}

/// Design: `MqttState` methods will just modify the state of the object
Expand All @@ -60,7 +60,8 @@ impl MqttState {
last_flush: Instant::now(),
last_pkid: mqtt3::PacketIdentifier(0),
outgoing_pub: VecDeque::new(),
subscriptions: Vec::new(),
subscriptions: HashMap::new(),
path_usage: HashMap::new(),
}
}

Expand Down Expand Up @@ -97,13 +98,13 @@ impl MqttState {
use self::MqttConnectionStatus::*;
use ReconnectOptions::*;
match (self.connection_status, self.opts.reconnect) {
(Handshake { initial: true }, Always(d))
| (Handshake {..}, AfterFirstSuccess(d))
(Handshake { initial: true }, Always(d))
| (Handshake { .. }, AfterFirstSuccess(d))
| (Connected, AfterFirstSuccess(d))
| (Connected, Always(d))
| (WantConnect { .. }, AfterFirstSuccess(d))
| (WantConnect { .. }, Always(d))
=> self.connection_status = WantConnect { when: Instant::now()+d },
=> self.connection_status = WantConnect { when: Instant::now() + d },
_ => self.connection_status = Disconnected
}
}
Expand Down Expand Up @@ -131,13 +132,14 @@ impl MqttState {
} else {
let sub = if self.subscriptions.len() > 0 {
Some(mqtt3::Subscribe {
pid: self.next_pkid(),
topics: self.subscriptions.iter().map(|s| {
pid: self.next_pkid(),
topics: self.subscriptions.iter().map(|(_id, s)| {
::mqtt3::SubscribeTopic {
topic_path: s.topic_path.path.clone(),
qos: s.qos,
}}).collect()
})
}
}).collect(),
})
} else {
None
};
Expand Down Expand Up @@ -202,7 +204,7 @@ impl MqttState {
let qos = publish.qos;

let concrete = ::mqtt3::TopicPath::from_str(&publish.topic_name)?;
for sub in &self.subscriptions {
for (_id, sub) in &self.subscriptions {
if sub.topic_path.is_match(&concrete) {
(sub.callback)(&publish);
}
Expand Down Expand Up @@ -289,7 +291,12 @@ impl MqttState {
}
})
.collect();
self.subscriptions.extend(subs);
for s in &subs {
*self.path_usage.entry(s.topic_path.path.clone()).or_insert(0) += 1;
}
self.subscriptions.extend(subs.into_iter().map(|it| {
(it.id, it)
}));

if self.connection_status == MqttConnectionStatus::Connected {
Ok(mqtt3::Subscribe { pid: pkid, topics })
Expand All @@ -302,14 +309,47 @@ impl MqttState {
}
}

pub fn handle_outgoing_unsubscribe(
&mut self,
ids: Vec<usize>,
) -> Result<Option<mqtt3::Unsubscribe>> {
let mut topics = vec![];
for id in ids {
if let Some(sub) = self.subscriptions.remove(&id) {
// we unwrap here because if the value is not there, there is an error in this code
let mut path_count = self.path_usage.get_mut(&sub.topic_path.path).unwrap();
*path_count -= 1;
if *path_count == 0 { topics.push(sub.topic_path.path) }
}
}
if !topics.is_empty() {
let pkid = self.next_pkid();

if self.connection_status == MqttConnectionStatus::Connected {
Ok(Some(mqtt3::Unsubscribe { pid: pkid, topics }))
} else {
error!(
"State = {:?}. Shouldn't unsubscribe in this state",
self.connection_status
);
Err(ErrorKind::InvalidState.into())
}
} else {
Ok(None)
}
}

pub fn handle_incoming_suback(&mut self, ack: mqtt3::Suback) -> Result<()> {
if ack.return_codes
.iter()
.any(|v| *v == ::mqtt3::SubscribeReturnCodes::Failure)
{
Err(format!("rejected subscription"))?
};
{
Err(format!("rejected subscription"))?
};
Ok(())
}

pub fn handle_incoming_unsuback(&mut self, ack: mqtt3::PacketIdentifier) -> Result<()> {
Ok(())
}

Expand Down Expand Up @@ -340,14 +380,13 @@ impl MqttState {

#[cfg(test)]
mod test {
use error::*;
use mqtt3::*;
use options::MqttOptions;
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use super::{MqttConnectionStatus, MqttState};
use mqtt3::*;
use options::MqttOptions;
use error::*;

#[test]
fn next_pkid_roll() {
Expand Down Expand Up @@ -546,7 +585,7 @@ mod test {
mqtt.handle_socket_disconnect();
assert_eq!(mqtt.outgoing_pub.len(), 0);
match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
assert_eq!(mqtt.await_pingresp, false);
Expand Down Expand Up @@ -574,7 +613,7 @@ mod test {
mqtt.handle_socket_disconnect();
assert_eq!(mqtt.outgoing_pub.len(), 3);
match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
assert_eq!(mqtt.await_pingresp, false);
Expand All @@ -585,7 +624,7 @@ mod test {
let mut mqtt = MqttState::new(MqttOptions::new("test-id", "127.0.0.1:1883"));

match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
mqtt.handle_outgoing_connect(true);
Expand All @@ -609,7 +648,7 @@ mod test {

assert!(mqtt.handle_incoming_connack(connack).is_err());
match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
}
Expand Down
Loading