diff --git a/Cargo.toml b/Cargo.toml index 52c5813d..f75a47a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,7 @@ waker-fn = "^1.1" [dev-dependencies] async-global-executor = "^2.0" +async-io = "^2.0" futures-lite = "^2.0" serde_json = "^1.0" waker-fn = "^1.1" diff --git a/examples/c.rs b/examples/c.rs new file mode 100644 index 00000000..e870c6e7 --- /dev/null +++ b/examples/c.rs @@ -0,0 +1,58 @@ +use futures_lite::StreamExt; +use lapin::{options::*, types::FieldTable, Connection, ConnectionProperties}; +use tracing::info; + +fn main() { + if std::env::var("RUST_LOG").is_err() { + unsafe { std::env::set_var("RUST_LOG", "info") }; + } + + tracing_subscriber::fmt::init(); + + let addr = std::env::var("AMQP_ADDR").unwrap_or_else(|_| "amqp://127.0.0.1:5672/%2f".into()); + + async_global_executor::block_on(async { + let conn = Connection::connect(&addr, ConnectionProperties::default()) + .await + .expect("connection error"); + + info!("CONNECTED"); + + //receive channel + let channel = conn.create_channel().await.expect("create_channel"); + info!(state=?conn.status().state()); + + let queue = channel + .queue_declare( + "hello-recover", + QueueDeclareOptions::default(), + FieldTable::default(), + ) + .await + .expect("queue_declare"); + info!(state=?conn.status().state()); + info!(?queue, "Declared queue"); + + info!("will consume"); + let mut consumer = channel + .basic_consume( + "hello-recover", + "my_consumer", + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await + .expect("basic_consume"); + info!(state=?conn.status().state()); + + while let Some(delivery) = consumer.next().await { + info!(message=?delivery, "received message"); + if let Ok(delivery) = delivery { + delivery + .ack(BasicAckOptions::default()) + .await + .expect("basic_ack"); + } + } + }) +} diff --git a/examples/p.rs b/examples/p.rs new file mode 100644 index 00000000..7abb29fa --- /dev/null +++ b/examples/p.rs @@ -0,0 +1,100 @@ +use lapin::{ + options::*, types::FieldTable, BasicProperties, Connection, ConnectionProperties, +}; +use tracing::info; + +fn main() { + if std::env::var("RUST_LOG").is_err() { + std::env::set_var("RUST_LOG", "info"); + } + + tracing_subscriber::fmt::init(); + + let addr = std::env::var("AMQP_ADDR").unwrap_or_else(|_| "amqp://127.0.0.1:5672/%2f".into()); + let recovery_config = lapin::experimental::RecoveryConfig { + auto_recover_channels: true, + }; + + async_global_executor::block_on(async { + let conn = Connection::connect( + &addr, + ConnectionProperties::default().with_experimental_recovery_config(recovery_config), + ) + .await + .expect("connection error"); + + info!("CONNECTED"); + + let channel1 = conn.create_channel().await.expect("create_channel"); + channel1 + .confirm_select(ConfirmSelectOptions::default()) + .await + .expect("confirm_select"); + channel1 + .queue_declare( + "hello-recover", + QueueDeclareOptions::default(), + FieldTable::default(), + ) + .await + .expect("queue_declare"); + + let ch = channel1.clone(); + async_global_executor::spawn(async move { + loop { + async_io::Timer::after(std::time::Duration::from_secs(1)).await; + info!("Trigger failure"); + assert!(ch + .queue_declare( + "fake queue", + QueueDeclareOptions { + passive: true, + ..QueueDeclareOptions::default() + }, + FieldTable::default(), + ) + .await + .is_err()); + } + }) + .detach(); + + let mut published = 0; + let mut errors = 0; + info!("will publish"); + loop { + let res = channel1 + .basic_publish( + "", + "recover-test", + BasicPublishOptions::default(), + b"before", + BasicProperties::default(), + ) + .await; + let res = if let Ok(res) = res { + res.await.map(|_| ()) + } else { + res.map(|_| ()) + }; + match res { + Ok(()) => { + println!("GOT OK"); + published += 1; + } + Err(err) => { + println!("GOT ERROR"); + let (soft, notifier) = err.is_amqp_soft_error(); + if !soft { + panic!("{}", err); + } + errors += 1; + if let Some(notifier) = notifier { + notifier.await + } + } + } + println!("Published {} with {} errors", published, errors); + } + }); +} diff --git a/examples/t.rs b/examples/t.rs new file mode 100644 index 00000000..22dfbc6f --- /dev/null +++ b/examples/t.rs @@ -0,0 +1,122 @@ +use lapin::{ + message::DeliveryResult, options::*, publisher_confirm::Confirmation, types::FieldTable, + BasicProperties, Connection, ConnectionProperties, +}; +use tracing::info; + +fn main() { + if std::env::var("RUST_LOG").is_err() { + std::env::set_var("RUST_LOG", "info"); + } + + tracing_subscriber::fmt::init(); + + let addr = std::env::var("AMQP_ADDR").unwrap_or_else(|_| "amqp://127.0.0.1:5672/%2f".into()); + let recovery_config = lapin::experimental::RecoveryConfig { + auto_recover_channels: true, + }; + + async_global_executor::block_on(async { + let conn = Connection::connect( + &addr, + ConnectionProperties::default().with_experimental_recovery_config(recovery_config), + ) + .await + .expect("connection error"); + + info!("CONNECTED"); + + { + let channel1 = conn.create_channel().await.expect("create_channel"); + let channel2 = conn.create_channel().await.expect("create_channel"); + channel1 + .confirm_select(ConfirmSelectOptions::default()) + .await + .expect("confirm_select"); + channel1 + .queue_declare( + "recover-test", + QueueDeclareOptions::default(), + FieldTable::default(), + ) + .await + .expect("queue_declare"); + + info!("will consume"); + let channel = channel2.clone(); + channel2 + .basic_consume( + "recover-test", + "my_consumer", + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await + .expect("basic_consume") + .set_delegate(move |delivery: DeliveryResult| { + let channel = channel.clone(); + async move { + info!(message=?delivery, "received message"); + if let Ok(Some(delivery)) = delivery { + delivery + .ack(BasicAckOptions::default()) + .await + .expect("basic_ack"); + if &delivery.data[..] == b"after" { + channel + .basic_cancel("my_consumer", BasicCancelOptions::default()) + .await + .expect("basic_cancel"); + } + } + } + }); + + info!("will publish"); + let confirm = channel1 + .basic_publish( + "", + "recover-test", + BasicPublishOptions::default(), + b"before", + BasicProperties::default(), + ) + .await + .expect("basic_publish") + .await + .expect("publisher-confirms"); + assert_eq!(confirm, Confirmation::Ack(None)); + + info!("before fail"); + assert!(channel1 + .queue_declare( + "fake queue", + QueueDeclareOptions { + passive: true, + ..QueueDeclareOptions::default() + }, + FieldTable::default(), + ) + .await + .is_err()); + info!("after fail"); + + info!("publish after"); + let confirm = channel1 + .basic_publish( + "", + "recover-test", + BasicPublishOptions::default(), + b"after", + BasicProperties::default(), + ) + .await + .expect("basic_publish") + .await + .expect("publisher-confirms"); + assert_eq!(confirm, Confirmation::Ack(None)); + } + + conn.run().expect("conn.run"); + }); +} diff --git a/src/acker.rs b/src/acker.rs index 401426b0..53494c13 100644 --- a/src/acker.rs +++ b/src/acker.rs @@ -78,10 +78,13 @@ impl Acker { async fn rpc)>(&self, f: F) -> Result<()> { if self.used.swap(true, Ordering::SeqCst) { - return Err(Error::ProtocolError(AMQPError::new( - AMQPSoftError::PRECONDITIONFAILED.into(), - "Attempted to use an already used Acker".into(), - ))); + return Err(Error::ProtocolError( + AMQPError::new( + AMQPSoftError::PRECONDITIONFAILED.into(), + "Attempted to use an already used Acker".into(), + ), + None, + )); } if let Some(error) = self.error.as_ref() { error.check()?; diff --git a/src/acknowledgement.rs b/src/acknowledgement.rs index f5ce8886..5bf65a73 100644 --- a/src/acknowledgement.rs +++ b/src/acknowledgement.rs @@ -62,6 +62,10 @@ impl Acknowledgements { pub(crate) fn on_channel_error(&self, error: Error) { self.0.lock().on_channel_error(error); } + + pub(crate) fn reset(&self, error: Error) { + self.0.lock().reset(error); + } } impl fmt::Debug for Acknowledgements { @@ -174,4 +178,9 @@ impl Inner { } } } + + fn reset(&mut self, error: Error) { + self.delivery_tag = IdSequence::new(false); + self.on_channel_error(error); + } } diff --git a/src/channel.rs b/src/channel.rs index 198b45ff..da7fdd93 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -86,6 +86,7 @@ impl fmt::Debug for Channel { } impl Channel { + #[allow(clippy::too_many_arguments)] pub(crate) fn new( channel_id: ChannelId, configuration: Configuration, @@ -139,10 +140,6 @@ impl Channel { self.error_handler.set_handler(handler); } - pub(crate) fn reset(&self) { - // FIXME - } - pub(crate) async fn restore( &self, ch: &ChannelDefinitionInternal, @@ -298,7 +295,7 @@ impl Channel { class_id, method_id, ); - Err(Error::ProtocolError(error)) + Err(Error::ProtocolError(error, None)) } } @@ -429,7 +426,7 @@ impl Channel { class_id, method_id, ); - Err(Error::ProtocolError(error)) + Err(Error::ProtocolError(error, None)) } pub(crate) fn handle_content_header_frame( @@ -468,7 +465,7 @@ impl Channel { class_id, 0, ); - let error = Error::ProtocolError(error); + let error = Error::ProtocolError(error, None); self.set_connection_error(error.clone()); Err(error) }, @@ -537,7 +534,7 @@ impl Channel { ) .await }); - Err(Error::ProtocolError(err)) + Err(Error::ProtocolError(err, None)) } fn before_connection_start_ok( @@ -556,7 +553,7 @@ impl Channel { } fn on_connection_close_ok_sent(&self, error: Error) { - if let Error::ProtocolError(_) = error { + if let Error::ProtocolError(_, _) = error { self.internal_rpc.set_connection_error(error); } else { self.internal_rpc.set_connection_closed(error); @@ -564,8 +561,10 @@ impl Channel { } fn next_expected_close_ok_reply(&self) -> Option { - self.frames - .next_expected_close_ok_reply(self.id, Error::InvalidChannelState(ChannelState::Closed)) + self.frames.next_expected_close_ok_reply( + self.id, + Error::InvalidChannelState(ChannelState::Closed, None), + ) } fn before_channel_close(&self) { @@ -573,13 +572,17 @@ impl Channel { } fn on_channel_close_ok_sent(&self, error: Option) { - self.set_closed( - error - .clone() - .unwrap_or(Error::InvalidChannelState(ChannelState::Closing)), - ); - if let Some(error) = error { - self.error_handler.on_error(error); + if !self.recovery_config.auto_recover_channels + || !error.as_ref().map_or(false, |e| e.is_amqp_soft_error().0) + { + self.set_closed( + error + .clone() + .unwrap_or(Error::InvalidChannelState(ChannelState::Closing, None)), + ); + if let Some(error) = error { + self.error_handler.on_error(error); + } } } @@ -819,7 +822,7 @@ impl Channel { ?error, "Connection closed", ); - Error::ProtocolError(error) + Error::ProtocolError(error, None) }) .unwrap_or_else(|error| { error!(%error); @@ -862,7 +865,19 @@ impl Channel { resolver: PromiseResolver, channel: Channel, ) -> Result<()> { - self.set_state(ChannelState::Connected); + if self.recovery_config.auto_recover_channels { + self.status.update_recovery_context(|ctx| { + ctx.set_expected_replies(self.frames.take_expected_replies(self.id)); + self.frames.drop_frames_for_channel(channel.id, ctx.cause()); + self.acknowledgements.reset(ctx.cause()); + self.consumers.error(ctx.cause()); + }); + if !self.status.confirm() { + self.status.finalize_recovery(); + } + } else { + self.set_state(ChannelState::Connected); + } resolver.resolve(channel); Ok(()) } @@ -896,18 +911,37 @@ impl Channel { channel=%self.id, ?method, ?error, "Channel closed" ); - Error::ProtocolError(error) + Error::ProtocolError(error, None) }); - self.set_closing(error.clone().ok()); + match ( + self.recovery_config.auto_recover_channels, + error.clone().ok(), + ) { + (true, Some(error)) if error.is_amqp_soft_error().0 => { + self.status.set_reconnecting(error) + } + (_, err) => self.set_closing(err), + } let error = error.map_err(|error| info!(channel=%self.id, ?method, code_to_error=%error, "Channel closed with a non-error code")).ok(); let channel = self.clone(); - self.internal_rpc - .register_internal_future(async move { channel.channel_close_ok(error).await }); + self.internal_rpc.register_internal_future(async move { + channel.channel_close_ok(error).await?; + if channel.recovery_config.auto_recover_channels { + let ch = channel.clone(); + channel.channel_open(ch).await?; + if channel.status.confirm() { + channel + .confirm_select(ConfirmSelectOptions::default()) + .await?; + } + } + Ok(()) + }); Ok(()) } fn on_channel_close_ok_received(&self) -> Result<()> { - self.set_closed(Error::InvalidChannelState(ChannelState::Closed)); + self.set_closed(Error::InvalidChannelState(ChannelState::Closed, None)); Ok(()) } diff --git a/src/channel_recovery_context.rs b/src/channel_recovery_context.rs new file mode 100644 index 00000000..936123f1 --- /dev/null +++ b/src/channel_recovery_context.rs @@ -0,0 +1,46 @@ +use crate::{ + frames::{ExpectedReply, Frames}, + notifier::Notifier, + Error, +}; + +use std::collections::VecDeque; + +pub(crate) struct ChannelRecoveryContext { + cause: Error, + expected_replies: Option>, + notifier: Notifier, +} + +impl ChannelRecoveryContext { + pub(crate) fn new(cause: Error) -> Self { + let notifier = Notifier::default(); + Self { + cause: cause.with_notifier(notifier.clone()), + expected_replies: None, + notifier, + } + } + + pub(crate) fn cause(&self) -> Error { + self.cause.clone() + } + + pub(crate) fn notifier(&self) -> Notifier { + self.notifier.clone() + } + + pub(crate) fn set_expected_replies( + &mut self, + expected_replies: Option>, + ) { + self.expected_replies = expected_replies; + } + + pub(crate) fn finalize_recovery(self) { + self.notifier.notify_all(); + if let Some(replies) = self.expected_replies { + Frames::cancel_expected_replies(replies, self.cause); + } + } +} diff --git a/src/channel_status.rs b/src/channel_status.rs index 230020e2..82dd464d 100644 --- a/src/channel_status.rs +++ b/src/channel_status.rs @@ -1,7 +1,9 @@ use crate::{ channel_receiver_state::{ChannelReceiverStates, DeliveryCause}, + channel_recovery_context::ChannelRecoveryContext, + notifier::Notifier, types::{ChannelId, Identifier, PayloadSize}, - Result, + Error, Result, }; use parking_lot::Mutex; use std::{fmt, sync::Arc}; @@ -12,19 +14,43 @@ pub struct ChannelStatus(Arc>); impl ChannelStatus { pub fn initializing(&self) -> bool { - self.0.lock().state == ChannelState::Initial + [ChannelState::Initial, ChannelState::Reconnecting].contains(&self.0.lock().state) } pub fn closing(&self) -> bool { - self.0.lock().state == ChannelState::Closing + [ChannelState::Closing, ChannelState::Reconnecting].contains(&self.0.lock().state) } pub fn connected(&self) -> bool { self.0.lock().state == ChannelState::Connected } + pub fn reconnecting(&self) -> bool { + self.0.lock().state == ChannelState::Reconnecting + } + + pub(crate) fn connected_or_recovering(&self) -> bool { + [ChannelState::Connected, ChannelState::Reconnecting].contains(&self.0.lock().state) + } + + pub(crate) fn update_recovery_context(&self, apply: F) { + let mut inner = self.0.lock(); + if let Some(context) = inner.recovery_context.as_mut() { + apply(context); + } + } + + pub(crate) fn finalize_recovery(&self) { + self.0.lock().finalize_recovery(); + } + pub(crate) fn can_receive_messages(&self) -> bool { - [ChannelState::Closing, ChannelState::Connected].contains(&self.0.lock().state) + [ + ChannelState::Closing, + ChannelState::Connected, + ChannelState::Reconnecting, + ] + .contains(&self.0.lock().state) } pub fn confirm(&self) -> bool { @@ -32,18 +58,31 @@ impl ChannelStatus { } pub(crate) fn set_confirm(&self) { - self.0.lock().confirm = true; + let mut inner = self.0.lock(); + inner.confirm = true; trace!("Publisher confirms activated"); + inner.finalize_recovery(); } pub fn state(&self) -> ChannelState { - self.0.lock().state.clone() + self.0.lock().state } pub(crate) fn set_state(&self, state: ChannelState) { self.0.lock().state = state; } + pub fn state_error(&self) -> Error { + let inner = self.0.lock(); + Error::InvalidChannelState(inner.state, inner.notifier()) + } + + pub(crate) fn set_reconnecting(&self, error: Error) { + let mut inner = self.0.lock(); + inner.state = ChannelState::Reconnecting; + inner.recovery_context = Some(ChannelRecoveryContext::new(error)); + } + pub(crate) fn auto_close(&self, id: ChannelId) -> bool { id != 0 && self.0.lock().state == ChannelState::Connected } @@ -112,10 +151,11 @@ impl ChannelStatus { } } -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub enum ChannelState { #[default] Initial, + Reconnecting, Connected, Closing, Closed, @@ -141,6 +181,7 @@ struct Inner { send_flow: bool, state: ChannelState, receiver_state: ChannelReceiverStates, + recovery_context: Option, } impl Default for Inner { @@ -150,6 +191,20 @@ impl Default for Inner { send_flow: true, state: ChannelState::default(), receiver_state: ChannelReceiverStates::default(), + recovery_context: None, } } } + +impl Inner { + pub(crate) fn finalize_recovery(&mut self) { + self.state = ChannelState::Connected; + if let Some(ctx) = self.recovery_context.take() { + ctx.finalize_recovery(); + } + } + + fn notifier(&self) -> Option { + Some(self.recovery_context.as_ref()?.notifier()) + } +} diff --git a/src/channels.rs b/src/channels.rs index 06481dd6..630fd706 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -225,7 +225,7 @@ impl Channels { .await }); } - return Err(Error::ProtocolError(error)); + return Err(Error::ProtocolError(error, None)); } } AMQPFrame::Header(channel_id, class_id, header) => { @@ -248,7 +248,7 @@ impl Channels { .await }); } - return Err(Error::ProtocolError(error)); + return Err(Error::ProtocolError(error, None)); } else { self.handle_content_header_frame( channel_id, diff --git a/src/connection.rs b/src/connection.rs index e457ae37..e5a55d50 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -152,7 +152,6 @@ impl Connection { .channels .push(RestoredChannel::new(if let Some(c) = c.channel.clone() { let channel = c.clone(); - c.reset(); c.channel_open(channel).await? } else { self.create_channel().await? diff --git a/src/error.rs b/src/error.rs index 235d19ae..1528a4c8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use crate::{ - channel_status::ChannelState, connection_status::ConnectionState, protocol::AMQPError, - types::ChannelId, + channel_status::ChannelState, connection_status::ConnectionState, notifier::Notifier, + protocol::AMQPError, types::ChannelId, }; use amq_protocol::{ frame::{GenError, ParserError, ProtocolVersion}, @@ -22,12 +22,12 @@ pub enum Error { InvalidProtocolVersion(ProtocolVersion), InvalidChannel(ChannelId), - InvalidChannelState(ChannelState), + InvalidChannelState(ChannelState, Option), InvalidConnectionState(ConnectionState), IOError(Arc), ParsingError(ParserError), - ProtocolError(AMQPError), + ProtocolError(AMQPError, Option), SerialisationError(Arc), MissingHeartbeatError, @@ -53,23 +53,30 @@ impl Error { } } - pub fn is_amqp_soft_error(&self) -> bool { - if let Error::ProtocolError(e) = self { + pub fn is_amqp_soft_error(&self) -> (bool, Option) { + if let Error::ProtocolError(e, notifier) = self { if let AMQPErrorKind::Soft(_) = e.kind() { - return true; + return (true, notifier.clone()); } } - false + (false, None) } pub fn is_amqp_hard_error(&self) -> bool { - if let Error::ProtocolError(e) = self { + if let Error::ProtocolError(e, _) = self { if let AMQPErrorKind::Hard(_) = e.kind() { return true; } } false } + + pub(crate) fn with_notifier(self, notifier: Notifier) -> Self { + match self { + Self::ProtocolError(err, _) => Self::ProtocolError(err, Some(notifier)), + err => err, + } + } } impl fmt::Display for Error { @@ -84,14 +91,14 @@ impl fmt::Display for Error { } Error::InvalidChannel(channel) => write!(f, "invalid channel: {}", channel), - Error::InvalidChannelState(state) => write!(f, "invalid channel state: {:?}", state), + Error::InvalidChannelState(state, _) => write!(f, "invalid channel state: {:?}", state), Error::InvalidConnectionState(state) => { write!(f, "invalid connection state: {:?}", state) } Error::IOError(e) => write!(f, "IO error: {}", e), Error::ParsingError(e) => write!(f, "failed to parse: {}", e), - Error::ProtocolError(e) => write!(f, "protocol error: {}", e), + Error::ProtocolError(e, _) => write!(f, "protocol error: {}", e), Error::SerialisationError(e) => write!(f, "failed to serialise: {}", e), Error::MissingHeartbeatError => { @@ -119,7 +126,7 @@ impl error::Error for Error { match self { Error::IOError(e) => Some(&**e), Error::ParsingError(e) => Some(e), - Error::ProtocolError(e) => Some(e), + Error::ProtocolError(e, _) => Some(e), Error::SerialisationError(e) => Some(&**e), _ => None, } @@ -144,7 +151,7 @@ impl PartialEq for Error { } (InvalidChannel(left_inner), InvalidChannel(right_inner)) => left_inner == right_inner, - (InvalidChannelState(left_inner), InvalidChannelState(right_inner)) => { + (InvalidChannelState(left_inner, _), InvalidChannelState(right_inner, _)) => { left_inner == right_inner } (InvalidConnectionState(left_inner), InvalidConnectionState(right_inner)) => { @@ -156,7 +163,9 @@ impl PartialEq for Error { false } (ParsingError(left_inner), ParsingError(right_inner)) => left_inner == right_inner, - (ProtocolError(left_inner), ProtocolError(right_inner)) => left_inner == right_inner, + (ProtocolError(left_inner, _), ProtocolError(right_inner, _)) => { + left_inner == right_inner + } (SerialisationError(_), SerialisationError(_)) => { error!("Unable to compare lapin::Error::SerialisationError"); false diff --git a/src/frames.rs b/src/frames.rs index 139f6df9..2472f43c 100644 --- a/src/frames.rs +++ b/src/frames.rs @@ -71,7 +71,7 @@ impl Frames { pub(crate) fn next_expected_close_ok_reply( &self, - channel_id: u16, + channel_id: ChannelId, error: Error, ) -> Option { self.inner @@ -87,8 +87,25 @@ impl Frames { self.inner.lock().drop_pending(error); } + pub(crate) fn take_expected_replies( + &self, + channel_id: ChannelId, + ) -> Option> { + self.inner.lock().expected_replies.remove(&channel_id) + } + pub(crate) fn clear_expected_replies(&self, channel_id: ChannelId, error: Error) { - self.inner.lock().clear_expected_replies(channel_id, error); + if let Some(replies) = self.take_expected_replies(channel_id) { + Self::cancel_expected_replies(replies, error) + } + } + + pub(crate) fn cancel_expected_replies(replies: VecDeque, error: Error) { + Inner::cancel_expected_replies(replies, error) + } + + pub(crate) fn drop_frames_for_channel(&self, channel_id: ChannelId, error: Error) { + self.inner.lock().drop_frames_for_channel(channel_id, error) } pub(crate) fn poison(&self) -> Option { @@ -253,7 +270,36 @@ impl Inner { } } - fn next_expected_close_ok_reply(&mut self, channel_id: u16, error: Error) -> Option { + fn drop_frames_for_channel(&mut self, channel_id: ChannelId, error: Error) { + Self::drop_pending_frames_for_channel(channel_id, &mut self.retry_frames, error.clone()); + Self::drop_pending_frames_for_channel(channel_id, &mut self.publish_frames, error.clone()); + Self::drop_pending_frames_for_channel(channel_id, &mut self.frames, error.clone()); + Self::drop_pending_frames_for_channel(channel_id, &mut self.low_prio_frames, error); + } + + fn drop_pending_frames_for_channel( + channel_id: ChannelId, + frames: &mut VecDeque<(AMQPFrame, Option>)>, + error: Error, + ) { + use AMQPFrame::*; + + frames.retain(|(f, r)| match f { + Method(id, _) | Header(id, _, _) | Body(id, _) | Heartbeat(id) if *id == channel_id => { + if let Some(r) = r { + r.reject(error.clone()); + } + false + } + _ => true, + }) + } + + fn next_expected_close_ok_reply( + &mut self, + channel_id: ChannelId, + error: Error, + ) -> Option { let expected_replies = self.expected_replies.get_mut(&channel_id)?; while let Some(reply) = expected_replies.pop_front() { match &reply.0 { @@ -265,12 +311,6 @@ impl Inner { None } - fn clear_expected_replies(&mut self, channel_id: ChannelId, error: Error) { - if let Some(replies) = self.expected_replies.remove(&channel_id) { - Self::cancel_expected_replies(replies, error); - } - } - fn cancel_expected_replies(replies: VecDeque, error: Error) { for ExpectedReply(reply, cancel) in replies { match reply { diff --git a/src/generated.rs b/src/generated.rs index 4abae5f2..6fbdb386 100644 --- a/src/generated.rs +++ b/src/generated.rs @@ -384,7 +384,7 @@ impl Channel { options: BasicQosOptions, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let BasicQosOptions { global } = options; @@ -414,7 +414,7 @@ impl Channel { } fn receive_basic_qos_ok(&self, method: protocol::basic::QosOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self @@ -445,7 +445,7 @@ impl Channel { original: Option, ) -> Result { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let creation_arguments = arguments.clone(); @@ -500,7 +500,7 @@ impl Channel { } fn receive_basic_consume_ok(&self, method: protocol::basic::ConsumeOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -538,7 +538,7 @@ impl Channel { options: BasicCancelOptions, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.before_basic_cancel(consumer_tag); @@ -576,13 +576,13 @@ impl Channel { } fn receive_basic_cancel(&self, method: protocol::basic::Cancel) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_basic_cancel_received(method) } async fn basic_cancel_ok(&self, consumer_tag: &str) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Basic(protocol::basic::AMQPMethod::CancelOk( @@ -600,7 +600,7 @@ impl Channel { } fn receive_basic_cancel_ok(&self, method: protocol::basic::CancelOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -630,7 +630,7 @@ impl Channel { properties: BasicProperties, ) -> Result { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let start_hook_res = self.before_basic_publish(); @@ -652,13 +652,13 @@ impl Channel { } fn receive_basic_return(&self, method: protocol::basic::Return) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_basic_return_received(method) } fn receive_basic_deliver(&self, method: protocol::basic::Deliver) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_basic_deliver_received(method) } @@ -669,7 +669,7 @@ impl Channel { original: Option>>, ) -> Result> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let BasicGetOptions { no_ack } = options; @@ -700,7 +700,7 @@ impl Channel { } fn receive_basic_get_ok(&self, method: protocol::basic::GetOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self @@ -722,7 +722,7 @@ impl Channel { } fn receive_basic_get_empty(&self, method: protocol::basic::GetEmpty) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_basic_get_empty_received(method) } @@ -732,7 +732,7 @@ impl Channel { options: BasicAckOptions, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let BasicAckOptions { multiple } = options; @@ -751,7 +751,7 @@ impl Channel { } fn receive_basic_ack(&self, method: protocol::basic::Ack) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_basic_ack_received(method) } @@ -761,7 +761,7 @@ impl Channel { options: BasicRejectOptions, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let BasicRejectOptions { requeue } = options; @@ -781,7 +781,7 @@ impl Channel { } pub async fn basic_recover_async(&self, options: BasicRecoverAsyncOptions) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let BasicRecoverAsyncOptions { requeue } = options; @@ -799,7 +799,7 @@ impl Channel { } pub async fn basic_recover(&self, options: BasicRecoverOptions) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let BasicRecoverOptions { requeue } = options; @@ -828,7 +828,7 @@ impl Channel { } fn receive_basic_recover_ok(&self, method: protocol::basic::RecoverOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -855,7 +855,7 @@ impl Channel { options: BasicNackOptions, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let BasicNackOptions { multiple, requeue } = options; @@ -875,14 +875,14 @@ impl Channel { } fn receive_basic_nack(&self, method: protocol::basic::Nack) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_basic_nack_received(method) } fn receive_connection_start(&self, method: protocol::connection::Start) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_connection_start_received(method) } @@ -897,7 +897,7 @@ impl Channel { credentials: Credentials, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::StartOk( @@ -920,13 +920,13 @@ impl Channel { fn receive_connection_secure(&self, method: protocol::connection::Secure) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_connection_secure_received(method) } async fn connection_secure_ok(&self, response: &str) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::SecureOk( @@ -945,7 +945,7 @@ impl Channel { fn receive_connection_tune(&self, method: protocol::connection::Tune) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_connection_tune_received(method) } @@ -956,7 +956,7 @@ impl Channel { heartbeat: ShortUInt, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::TuneOk( @@ -981,7 +981,7 @@ impl Channel { conn_resolver: PromiseResolver, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::Open( @@ -1013,7 +1013,7 @@ impl Channel { fn receive_connection_open_ok(&self, method: protocol::connection::OpenOk) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1042,7 +1042,7 @@ impl Channel { method_id: ShortUInt, ) -> Result<()> { if !self.status.closing() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::Close( @@ -1076,13 +1076,13 @@ impl Channel { fn receive_connection_close(&self, method: protocol::connection::Close) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_connection_close_received(method) } pub(crate) async fn connection_close_ok(&self, error: Error) -> Result<()> { if !self.status.closing() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::CloseOk( @@ -1100,7 +1100,7 @@ impl Channel { fn receive_connection_close_ok(&self, method: protocol::connection::CloseOk) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1123,7 +1123,7 @@ impl Channel { } pub(crate) async fn connection_blocked(&self, reason: &str) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::Blocked( @@ -1142,13 +1142,13 @@ impl Channel { fn receive_connection_blocked(&self, method: protocol::connection::Blocked) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_connection_blocked_received(method) } pub(crate) async fn connection_unblocked(&self) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::Unblocked( @@ -1165,7 +1165,7 @@ impl Channel { fn receive_connection_unblocked(&self, method: protocol::connection::Unblocked) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_connection_unblocked_received(method) } @@ -1175,7 +1175,7 @@ impl Channel { reason: &str, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Connection(protocol::connection::AMQPMethod::UpdateSecret( @@ -1210,7 +1210,7 @@ impl Channel { ) -> Result<()> { self.assert_channel0(method.get_amqp_class_id(), method.get_amqp_method_id())?; if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| matches!(&reply.0, Reply::ConnectionUpdateSecretOk(..))){ @@ -1227,7 +1227,7 @@ impl Channel { } pub(crate) async fn channel_open(&self, channel: Channel) -> Result { if !self.status.initializing() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Channel(protocol::channel::AMQPMethod::Open( @@ -1255,7 +1255,7 @@ impl Channel { } fn receive_channel_open_ok(&self, method: protocol::channel::OpenOk) -> Result<()> { if !self.status.initializing() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1276,7 +1276,7 @@ impl Channel { } pub async fn channel_flow(&self, options: ChannelFlowOptions) -> Result { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let ChannelFlowOptions { active } = options; @@ -1305,13 +1305,13 @@ impl Channel { } fn receive_channel_flow(&self, method: protocol::channel::Flow) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_channel_flow_received(method) } async fn channel_flow_ok(&self, options: ChannelFlowOkOptions) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let ChannelFlowOkOptions { active } = options; @@ -1328,7 +1328,7 @@ impl Channel { } fn receive_channel_flow_ok(&self, method: protocol::channel::FlowOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1355,7 +1355,7 @@ impl Channel { method_id: ShortUInt, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.before_channel_close(); @@ -1389,13 +1389,13 @@ impl Channel { } fn receive_channel_close(&self, method: protocol::channel::Close) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_channel_close_received(method) } async fn channel_close_ok(&self, error: Option) -> Result<()> { if !self.status.closing() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Channel(protocol::channel::AMQPMethod::CloseOk( @@ -1412,7 +1412,7 @@ impl Channel { } fn receive_channel_close_ok(&self, method: protocol::channel::CloseOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.next_expected_close_ok_reply() { @@ -1433,7 +1433,7 @@ impl Channel { } pub async fn access_request(&self, realm: &str, options: AccessRequestOptions) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let AccessRequestOptions { @@ -1475,7 +1475,7 @@ impl Channel { } fn receive_access_request_ok(&self, method: protocol::access::RequestOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1505,7 +1505,7 @@ impl Channel { exchange_kind: ExchangeKind, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let creation_arguments = arguments.clone(); @@ -1559,7 +1559,7 @@ impl Channel { } fn receive_exchange_declare_ok(&self, method: protocol::exchange::DeclareOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1595,7 +1595,7 @@ impl Channel { options: ExchangeDeleteOptions, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let ExchangeDeleteOptions { if_unused, nowait } = options; @@ -1631,7 +1631,7 @@ impl Channel { } fn receive_exchange_delete_ok(&self, method: protocol::exchange::DeleteOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1661,7 +1661,7 @@ impl Channel { arguments: FieldTable, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let creation_arguments = arguments.clone(); @@ -1706,7 +1706,7 @@ impl Channel { } fn receive_exchange_bind_ok(&self, method: protocol::exchange::BindOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1747,7 +1747,7 @@ impl Channel { arguments: FieldTable, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let creation_arguments = arguments.clone(); @@ -1792,7 +1792,7 @@ impl Channel { } fn receive_exchange_unbind_ok(&self, method: protocol::exchange::UnbindOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1831,7 +1831,7 @@ impl Channel { arguments: FieldTable, ) -> Result { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let creation_arguments = arguments.clone(); @@ -1881,7 +1881,7 @@ impl Channel { } fn receive_queue_declare_ok(&self, method: protocol::queue::DeclareOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -1909,7 +1909,7 @@ impl Channel { arguments: FieldTable, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let creation_arguments = arguments.clone(); @@ -1952,7 +1952,7 @@ impl Channel { } fn receive_queue_bind_ok(&self, method: protocol::queue::BindOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self @@ -1991,7 +1991,7 @@ impl Channel { options: QueuePurgeOptions, ) -> Result { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let QueuePurgeOptions { nowait } = options; @@ -2021,7 +2021,7 @@ impl Channel { } fn receive_queue_purge_ok(&self, method: protocol::queue::PurgeOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self @@ -2047,7 +2047,7 @@ impl Channel { options: QueueDeleteOptions, ) -> Result { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let QueueDeleteOptions { @@ -2090,7 +2090,7 @@ impl Channel { } fn receive_queue_delete_ok(&self, method: protocol::queue::DeleteOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -2117,7 +2117,7 @@ impl Channel { arguments: FieldTable, ) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let creation_arguments = arguments.clone(); @@ -2157,7 +2157,7 @@ impl Channel { } fn receive_queue_unbind_ok(&self, method: protocol::queue::UnbindOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { @@ -2191,7 +2191,7 @@ impl Channel { } pub async fn tx_select(&self) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Tx(protocol::tx::AMQPMethod::Select(protocol::tx::Select {})); @@ -2217,7 +2217,7 @@ impl Channel { } fn receive_tx_select_ok(&self, method: protocol::tx::SelectOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self @@ -2241,7 +2241,7 @@ impl Channel { } pub async fn tx_commit(&self) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Tx(protocol::tx::AMQPMethod::Commit(protocol::tx::Commit {})); @@ -2267,7 +2267,7 @@ impl Channel { } fn receive_tx_commit_ok(&self, method: protocol::tx::CommitOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self @@ -2291,7 +2291,7 @@ impl Channel { } pub async fn tx_rollback(&self) -> Result<()> { if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } let method = AMQPClass::Tx(protocol::tx::AMQPMethod::Rollback( @@ -2319,7 +2319,7 @@ impl Channel { } fn receive_tx_rollback_ok(&self, method: protocol::tx::RollbackOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self @@ -2342,8 +2342,8 @@ impl Channel { } } pub async fn confirm_select(&self, options: ConfirmSelectOptions) -> Result<()> { - if !self.status.connected() { - return Err(Error::InvalidChannelState(self.status.state())); + if !self.status.connected_or_recovering() { + return Err(self.status.state_error()); } let ConfirmSelectOptions { nowait } = options; @@ -2372,7 +2372,7 @@ impl Channel { } fn receive_confirm_select_ok(&self, method: protocol::confirm::SelectOk) -> Result<()> { if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match self.frames.find_expected_reply(self.id, |reply| { diff --git a/src/id_sequence.rs b/src/id_sequence.rs index 89294b55..d8d069d3 100644 --- a/src/id_sequence.rs +++ b/src/id_sequence.rs @@ -29,7 +29,7 @@ impl< // self.id is actually the next (so that first call to next returns 0 // if we're 0 (or 1 if 0 is not allowed), either we haven't started yet, or last number we yielded (current one) is // the max. - if self.id == self.first() { + if self.id <= self.first() { self.max } else { Some(self.id - self.one) diff --git a/src/io_loop.rs b/src/io_loop.rs index 33113a66..6de68fed 100644 --- a/src/io_loop.rs +++ b/src/io_loop.rs @@ -440,7 +440,7 @@ impl IoLoop { 0, 0, ); - self.critical_error(Error::ProtocolError(error))?; + self.critical_error(Error::ProtocolError(error, None))?; } self.receive_buffer.consume(consumed); Ok(Some(f)) diff --git a/src/lib.rs b/src/lib.rs index 9f3d8d6f..f77efb1c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -136,6 +136,7 @@ mod buffer; mod channel; mod channel_closer; mod channel_receiver_state; +mod channel_recovery_context; mod channel_status; mod channels; mod configuration; @@ -156,6 +157,7 @@ mod id_sequence; mod internal_rpc; mod io_loop; mod killswitch; +mod notifier; mod parsing; mod promise; mod queue; diff --git a/src/notifier.rs b/src/notifier.rs new file mode 100644 index 00000000..bd0ba038 --- /dev/null +++ b/src/notifier.rs @@ -0,0 +1,48 @@ +use crate::wakers::Wakers; + +use std::{ + fmt, + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +#[derive(Default, Clone)] +pub struct Notifier { + done: Arc, + wakers: Arc, +} + +impl Notifier { + pub(crate) fn notify_all(&self) { + self.done.store(true, Ordering::Release); + self.wakers.wake(); + } + + fn ready(&self) -> bool { + self.done.load(Ordering::Acquire) + } +} + +impl Future for Notifier { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.ready() { + Poll::Ready(()) + } else { + self.wakers.register(cx.waker()); + Poll::Pending + } + } +} + +impl fmt::Debug for Notifier { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Notifier").finish() + } +} diff --git a/templates/channel.rs b/templates/channel.rs index 466f042e..cd64af04 100644 --- a/templates/channel.rs +++ b/templates/channel.rs @@ -69,10 +69,14 @@ impl Channel { {{#if method.metadata.channel_deinit ~}} if !self.status.closing() { {{else}} + {{#if method.metadata.channel_recovery ~}} + if !self.status.connected_or_recovering() { + {{else}} if !self.status.connected() { {{/if ~}} {{/if ~}} - return Err(Error::InvalidChannelState(self.status.state())); + {{/if ~}} + return Err(self.status.state_error()); } {{#if method.metadata.start_hook ~}} @@ -167,7 +171,7 @@ impl Channel { {{else}} if !self.status.can_receive_messages() { {{/if ~}} - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } match {{#if method.metadata.expected_reply_getter ~}}{{method.metadata.expected_reply_getter}}{{else}}self.frames.find_expected_reply(self.id, |reply| matches!(&reply.0, Reply::{{camel class.name}}{{camel method.name}}(..))){{/if ~}} { @@ -201,7 +205,7 @@ impl Channel { )?; {{/if ~}} if !self.status.can_receive_messages() { - return Err(Error::InvalidChannelState(self.status.state())); + return Err(self.status.state_error()); } self.on_{{snake class.name false}}_{{snake method.name false}}_received(method) } diff --git a/templates/lapin.json b/templates/lapin.json index 74861dab..56269b8f 100644 --- a/templates/lapin.json +++ b/templates/lapin.json @@ -151,6 +151,11 @@ } }, "confirm": { + "select": { + "metadata": { + "channel_recovery": true + } + }, "select-ok": { "metadata": { "received_hook": true