From 68ef2ad987d577082a3039abb0090cd35b5aba64 Mon Sep 17 00:00:00 2001 From: Daniele Palaia Date: Sun, 10 Nov 2024 17:21:00 +0100 Subject: [PATCH] Implementing callback support and consumer_update response --- protocol/src/commands/consumer_update.rs | 8 +++ protocol/src/request/mod.rs | 5 -- src/client/dispatcher.rs | 2 +- src/client/mod.rs | 12 ++++ src/consumer.rs | 87 +++++++++++++++++++----- src/environment.rs | 2 + src/lib.rs | 6 +- src/superstream_consumer.rs | 20 +++++- tests/integration/client_test.rs | 13 ---- 9 files changed, 116 insertions(+), 39 deletions(-) diff --git a/protocol/src/commands/consumer_update.rs b/protocol/src/commands/consumer_update.rs index 463ad0c..6d07472 100644 --- a/protocol/src/commands/consumer_update.rs +++ b/protocol/src/commands/consumer_update.rs @@ -27,6 +27,14 @@ impl ConsumerUpdateCommand { active, } } + + pub fn get_correlation_id(&self) -> u32 { + self.correlation_id + } + + pub fn is_active(&self) -> u8 { + self.active + } } impl Encoder for ConsumerUpdateCommand { diff --git a/protocol/src/request/mod.rs b/protocol/src/request/mod.rs index 31bd805..fe48803 100644 --- a/protocol/src/request/mod.rs +++ b/protocol/src/request/mod.rs @@ -397,9 +397,4 @@ mod tests { fn request_route_command() { request_encode_decode_test::() } - - #[test] - fn request_consumer_update_request_command() { - request_encode_decode_test::() - } } diff --git a/src/client/dispatcher.rs b/src/client/dispatcher.rs index 5ac3b7f..1abcf43 100644 --- a/src/client/dispatcher.rs +++ b/src/client/dispatcher.rs @@ -168,7 +168,7 @@ where match result { Ok(item) => match item.correlation_id() { Some(correlation_id) => match item.kind_ref() { - ResponseKind::ConsumerUpdate(consumer_update) => state.notify(item).await, + ResponseKind::ConsumerUpdate(_) => state.notify(item).await, _ => state.dispatch(correlation_id, item).await, }, None => state.notify(item).await, diff --git a/src/client/mod.rs b/src/client/mod.rs index 920186b..98bab85 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -42,6 +42,7 @@ pub use options::ClientOptions; use rabbitmq_stream_protocol::{ commands::{ close::{CloseRequest, CloseResponse}, + consumer_update_request::ConsumerUpdateRequestCommand, create_stream::CreateStreamCommand, create_super_stream::CreateSuperStreamCommand, credit::CreditCommand, @@ -851,4 +852,15 @@ impl Client { Ok(config) } + + pub async fn consumer_update( + &self, + correlation_id: u32, + offset_specification: OffsetSpecification, + ) -> RabbitMQStreamResult { + self.send_and_receive(|_| { + ConsumerUpdateRequestCommand::new(correlation_id, 1, offset_specification) + }) + .await + } } diff --git a/src/consumer.rs b/src/consumer.rs index a4650ce..f651b94 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -15,6 +15,8 @@ use rabbitmq_stream_protocol::{ commands::subscribe::OffsetSpecification, message::Message, ResponseKind, }; +use core::option::Option::None; + use tokio::sync::mpsc::{channel, Receiver, Sender}; use tracing::trace; @@ -26,13 +28,13 @@ use crate::{ Client, ClientOptions, Environment, MetricsCollector, }; use futures::{task::AtomicWaker, Stream}; -use rabbitmq_stream_protocol::commands::consumer_update::ConsumerUpdateCommand; use rand::rngs::StdRng; use rand::{seq::SliceRandom, SeedableRng}; type FilterPredicate = Option bool + Send + Sync>>; -type ConsumerUpdateListener = Option u64 + Send + Sync>>; +pub type ConsumerUpdateListener = + Arc OffsetSpecification + Send + Sync>; /// API for consuming RabbitMQ stream messages pub struct Consumer { @@ -43,6 +45,7 @@ pub struct Consumer { } struct ConsumerInternal { + name: Option, client: Client, stream: String, offset_specification: OffsetSpecification, @@ -52,6 +55,7 @@ struct ConsumerInternal { waker: AtomicWaker, metrics_collector: Arc, filter_configuration: Option, + consumer_update_listener: Option, } impl ConsumerInternal { @@ -86,22 +90,17 @@ impl FilterConfiguration { } pub struct MessageContext { - consumer: Consumer, - subscriber_name: String, - reference: String, + consumer_name: Option, + stream: String, } impl MessageContext { - pub fn get_consumer(self) -> Consumer { - self.consumer + pub fn get_name(self) -> Option { + self.consumer_name } - pub fn get_subscriber_name(self) -> String { - self.subscriber_name - } - - pub fn get_reference(self) -> String { - self.reference + pub fn get_stream(self) -> String { + self.stream } } @@ -111,6 +110,7 @@ pub struct ConsumerBuilder { pub(crate) environment: Environment, pub(crate) offset_specification: OffsetSpecification, pub(crate) filter_configuration: Option, + pub(crate) consumer_update_listener: Option, pub(crate) client_provided_name: String, pub(crate) properties: HashMap, } @@ -172,6 +172,7 @@ impl ConsumerBuilder { let subscription_id = 1; let (tx, rx) = channel(10000); let consumer = Arc::new(ConsumerInternal { + name: self.consumer_name.clone(), subscription_id, stream: stream.to_string(), client: client.clone(), @@ -181,6 +182,7 @@ impl ConsumerBuilder { waker: AtomicWaker::new(), metrics_collector: collector, filter_configuration: self.filter_configuration.clone(), + consumer_update_listener: self.consumer_update_listener.clone(), }); let msg_handler = ConsumerMessageHandler(consumer.clone()); client.set_handler(msg_handler).await; @@ -213,7 +215,7 @@ impl ConsumerBuilder { if response.is_ok() { Ok(Consumer { - name: self.consumer_name, + name: self.consumer_name.clone(), receiver: rx, internal: consumer, }) @@ -245,6 +247,26 @@ impl ConsumerBuilder { self } + pub fn consumer_update( + mut self, + consumer_update_listener: impl Fn(u8, &MessageContext) -> OffsetSpecification + + Send + + Sync + + 'static, + ) -> Self { + let f = Arc::new(consumer_update_listener); + self.consumer_update_listener = Some(f); + self + } + + pub fn consumer_update_arc( + mut self, + consumer_update_listener: Option, + ) -> Self { + self.consumer_update_listener = consumer_update_listener; + self + } + pub fn properties(mut self, properties: HashMap) -> Self { self.properties = properties; self @@ -386,8 +408,41 @@ impl MessageHandler for ConsumerMessageHandler { // TODO handle credit fail let _ = self.0.client.credit(self.0.subscription_id, 1).await; self.0.metrics_collector.consume(len as u64).await; - } else { - println!("other message arrived"); + } else if let ResponseKind::ConsumerUpdate(consumer_update) = response.kind_ref() { + trace!("Received a ConsumerUpdate message"); + // If no callback is provided by the user we will restart from Next by protocol + // We need to respond to the server too + if self.0.consumer_update_listener.is_none() { + trace!("User defined callback is not provided"); + let offset_specification = OffsetSpecification::Next; + let _ = self + .0 + .client + .consumer_update( + consumer_update.get_correlation_id(), + offset_specification, + ) + .await; + } else { + // Otherwise the Offset specification is returned by the user callback + let is_active = consumer_update.is_active(); + let message_context = MessageContext { + consumer_name: self.0.name.clone(), + stream: self.0.stream.clone(), + }; + let consumer_update_listener_callback = + self.0.consumer_update_listener.clone().unwrap(); + let offset_specification = + consumer_update_listener_callback(is_active, &message_context); + let _ = self + .0 + .client + .consumer_update( + consumer_update.get_correlation_id(), + offset_specification, + ) + .await; + } } } Some(Err(err)) => { diff --git a/src/environment.rs b/src/environment.rs index 290415d..283bad4 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -74,6 +74,7 @@ impl Environment { environment: self.clone(), offset_specification: OffsetSpecification::Next, filter_configuration: None, + consumer_update_listener: None, client_provided_name: String::from("rust-stream-consumer"), properties: HashMap::new(), } @@ -84,6 +85,7 @@ impl Environment { environment: self.clone(), offset_specification: OffsetSpecification::Next, filter_configuration: None, + consumer_update_listener: None, client_provided_name: String::from("rust-super-stream-consumer"), properties: HashMap::new(), } diff --git a/src/lib.rs b/src/lib.rs index 2fe8c66..67d15e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -87,14 +87,16 @@ pub type RabbitMQStreamResult = Result; pub use crate::client::{Client, ClientOptions, MetricsCollector}; -pub use crate::consumer::{Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration}; +pub use crate::consumer::{ + Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration, MessageContext, +}; pub use crate::environment::{Environment, EnvironmentBuilder, TlsConfiguration}; pub use crate::producer::{Dedup, NoDedup, Producer, ProducerBuilder}; pub mod types { pub use crate::byte_capacity::ByteCapacity; pub use crate::client::{Broker, MessageResult, StreamMetadata}; - pub use crate::consumer::Delivery; + pub use crate::consumer::{Delivery, MessageContext}; pub use crate::offset_specification::OffsetSpecification; pub use crate::stream_creator::StreamCreator; pub use crate::superstream::HashRoutingMurmurStrategy; diff --git a/src/superstream_consumer.rs b/src/superstream_consumer.rs index b7c61e1..cdb299c 100644 --- a/src/superstream_consumer.rs +++ b/src/superstream_consumer.rs @@ -1,8 +1,10 @@ use crate::client::Client; -use crate::consumer::Delivery; +use crate::consumer::{ConsumerUpdateListener, Delivery}; use crate::error::{ConsumerCloseError, ConsumerDeliveryError}; use crate::superstream::DefaultSuperStreamMetadata; -use crate::{error::ConsumerCreateError, ConsumerHandle, Environment, FilterConfiguration}; +use crate::{ + error::ConsumerCreateError, ConsumerHandle, Environment, FilterConfiguration, MessageContext, +}; use futures::task::AtomicWaker; use futures::{Stream, StreamExt}; use rabbitmq_stream_protocol::commands::subscribe::OffsetSpecification; @@ -33,6 +35,7 @@ pub struct SuperStreamConsumerBuilder { pub(crate) environment: Environment, pub(crate) offset_specification: OffsetSpecification, pub(crate) filter_configuration: Option, + pub(crate) consumer_update_listener: Option, pub(crate) client_provided_name: String, pub(crate) properties: HashMap, } @@ -64,6 +67,7 @@ impl SuperStreamConsumerBuilder { .offset(self.offset_specification.clone()) .client_provided_name(self.client_provided_name.as_str()) .filter_input(self.filter_configuration.clone()) + .consumer_update_arc(self.consumer_update_listener.clone()) .properties(self.properties.clone()) .build(partition.as_str()) .await @@ -101,6 +105,18 @@ impl SuperStreamConsumerBuilder { self } + pub fn consumer_update( + mut self, + consumer_update_listener: impl Fn(u8, &MessageContext) -> OffsetSpecification + + Send + + Sync + + 'static, + ) -> Self { + let f = Arc::new(consumer_update_listener); + self.consumer_update_listener = Some(f); + self + } + pub fn client_provided_name(mut self, name: &str) -> Self { self.client_provided_name = String::from(name); self diff --git a/tests/integration/client_test.rs b/tests/integration/client_test.rs index 03c9a4c..4051c48 100644 --- a/tests/integration/client_test.rs +++ b/tests/integration/client_test.rs @@ -457,16 +457,3 @@ async fn client_test_route_test() { test.partitions.get(0).unwrap() ); } - -#[tokio::test(flavor = "multi_thread")] -async fn client_consumer_update_request_test() { - let test = TestClient::create().await; - - let response = test - .client - .consumer_update(OffsetSpecification::Next) - .await - .unwrap(); - - assert_eq!(&ResponseCode::Ok, response.code()); -}