From d9d01e511b987c12ac5b2181fc271a22d771180c Mon Sep 17 00:00:00 2001 From: John Yang Date: Tue, 30 Apr 2024 14:50:10 -0700 Subject: [PATCH] update accumulator sig to return Result instead of TResult --- .../src/processing/strategies/reduce.rs | 154 +++++++++++++++--- 1 file changed, 132 insertions(+), 22 deletions(-) diff --git a/rust-arroyo/src/processing/strategies/reduce.rs b/rust-arroyo/src/processing/strategies/reduce.rs index d738b40a..25c0cc87 100644 --- a/rust-arroyo/src/processing/strategies/reduce.rs +++ b/rust-arroyo/src/processing/strategies/reduce.rs @@ -14,7 +14,9 @@ use super::InvalidMessage; struct BatchState { value: Option, - accumulator: Arc TResult + Send + Sync>, + accumulator: Arc< + dyn Fn(TResult, Message) -> Result, TResult)> + Send + Sync, + >, offsets: BTreeMap, batch_start_time: Deadline, message_count: usize, @@ -24,7 +26,9 @@ struct BatchState { impl BatchState { fn new( initial_value: TResult, - accumulator: Arc TResult + Send + Sync>, + accumulator: Arc< + dyn Fn(TResult, Message) -> Result, TResult)> + Send + Sync, + >, max_batch_time: Duration, compute_batch_size: fn(&T) -> usize, ) -> BatchState { @@ -38,21 +42,32 @@ impl BatchState { } } - fn add(&mut self, message: Message) { + fn add(&mut self, message: Message) -> Result<(), SubmitError> { for (partition, offset) in message.committable() { self.offsets.insert(partition, offset); } let tmp = self.value.take().unwrap(); - let payload = message.into_payload(); - self.message_count += (self.compute_batch_size)(&payload); - self.value = Some((self.accumulator)(tmp, payload)); + let message_count = (self.compute_batch_size)(&message.payload()); + match (self.accumulator)(tmp, message) { + Ok(result) => { + self.value = Some(result); + self.message_count += message_count; + Ok(()) + } + Err((submit_error, prev_result)) => { + self.value = Some(prev_result); + Err(submit_error) + } + } } } pub struct Reduce { next_step: Box>, - accumulator: Arc TResult + Send + Sync>, + accumulator: Arc< + dyn Fn(TResult, Message) -> Result, TResult)> + Send + Sync, + >, initial_value: Arc TResult + Send + Sync>, max_batch_size: usize, max_batch_time: Duration, @@ -79,7 +94,7 @@ impl ProcessingStrategy for Reduce ProcessingStrategy for Reduce Reduce { pub fn new( next_step: N, - accumulator: Arc TResult + Send + Sync>, + accumulator: Arc< + dyn Fn(TResult, Message) -> Result, TResult)> + Send + Sync, + >, initial_value: Arc TResult + Send + Sync>, max_batch_size: usize, max_batch_time: Duration, @@ -216,7 +233,7 @@ impl Reduce { mod tests { use crate::processing::strategies::reduce::Reduce; use crate::processing::strategies::{ - CommitRequest, ProcessingStrategy, StrategyError, SubmitError, + CommitRequest, MessageRejected, ProcessingStrategy, StrategyError, SubmitError, }; use crate::types::{BrokerMessage, InnerMessage, Message, Partition, Topic}; use std::sync::{Arc, Mutex}; @@ -256,9 +273,9 @@ mod tests { let max_batch_time = Duration::from_secs(1); let initial_value = Vec::new(); - let accumulator = Arc::new(|mut acc: Vec, value: u64| { - acc.push(value); - acc + let accumulator = Arc::new(|mut acc: Vec, msg: Message| { + acc.push(msg.into_payload()); + Ok(acc) }); let compute_batch_size = |_: &_| -> usize { 1 }; @@ -302,6 +319,99 @@ mod tests { ); } + #[test] + fn test_reduce_with_backpressure() { + let submitted_messages = Arc::new(Mutex::new(Vec::new())); + let submitted_messages_clone = submitted_messages.clone(); + + let partition1 = Partition::new(Topic::new("test"), 0); + + let max_batch_size = 2; + let max_batch_time = Duration::from_secs(1); + + #[derive(Clone, Debug, PartialEq)] + struct Buffer { + data: Vec, + flushed: bool, + } + + let initial_value = Buffer { + data: Vec::new(), + flushed: false, + }; + + let accumulator = Arc::new(move |mut acc: Buffer, msg: Message| { + if acc.flushed { + acc.data.push(msg.into_payload()); + acc.flushed = false; + Ok(acc) + } else { + acc.flushed = true; + Err(( + SubmitError::MessageRejected(MessageRejected { message: msg }), + acc, + )) + } + }); + let compute_batch_size = |_: &_| -> usize { 1 }; + + let next_step = NextStep { + submitted: submitted_messages, + }; + + let mut strategy = Reduce::new( + next_step, + accumulator, + Arc::new(move || initial_value.clone()), + max_batch_size, + max_batch_time, + compute_batch_size, + ); + + for i in 0..3 { + let msg = Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage::new( + i, + partition1, + i, + chrono::Utc::now(), + )), + }; + let res = strategy.submit(msg); + match res { + Err(SubmitError::MessageRejected(MessageRejected { message })) => { + strategy.submit(message).unwrap(); + } + _ => { + unreachable!("Strategy should have backpressured") + } + }; + let _ = strategy.poll(); + } + + // 3 messages with a max batch size of 2 means 1 batch was cleared + // and 1 message is left before next size limit. + assert_eq!(strategy.batch_state.message_count, 1); + + strategy.close(); + let _ = strategy.join(None); + + // 2 batches were created + assert_eq!( + *submitted_messages_clone.lock().unwrap(), + vec![ + Buffer { + data: vec![0, 1], + flushed: false + }, + Buffer { + data: vec![2], + flushed: false + } + ] + ); + } + #[test] fn test_reduce_with_custom_batch_size() { let submitted_messages = Arc::new(Mutex::new(Vec::new())); @@ -313,9 +423,9 @@ mod tests { let max_batch_time = Duration::from_secs(1); let initial_value = Vec::new(); - let accumulator = Arc::new(|mut acc: Vec, value: u64| { - acc.push(value); - acc + let accumulator = Arc::new(|mut acc: Vec, msg: Message| { + acc.push(msg.into_payload()); + Ok(acc) }); let compute_batch_size = |_: &_| -> usize { 5 }; @@ -370,9 +480,9 @@ mod tests { let max_batch_time = Duration::from_secs(100); let initial_value = Vec::new(); - let accumulator = Arc::new(|mut acc: Vec, value: u64| { - acc.push(value); - acc + let accumulator = Arc::new(|mut acc: Vec, msg: Message| { + acc.push(msg.into_payload()); + Ok(acc) }); let compute_batch_size = |_: &_| -> usize { 0 }; @@ -424,9 +534,9 @@ mod tests { let max_batch_time = Duration::from_secs(100); let initial_value = Vec::new(); - let accumulator = Arc::new(|mut acc: Vec, value: u64| { - acc.push(value); - acc + let accumulator = Arc::new(|mut acc: Vec, msg: Message| { + acc.push(msg.into_payload()); + Ok(acc) }); let compute_batch_size = |_: &_| -> usize { 0 };