Skip to content

Commit

Permalink
Added a circuit breaker layer
Browse files Browse the repository at this point in the history
  • Loading branch information
fulmicoton committed Jun 18, 2024
1 parent 89ce762 commit e5a20a6
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 4 deletions.
3 changes: 2 additions & 1 deletion quickwit/quickwit-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ siphasher = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-metrics ={ workspace = true }
tokio-metrics = { workspace = true }
tokio-stream = { workspace = true }
tonic = { workspace = true }
tower = { workspace = true }
Expand All @@ -51,3 +51,4 @@ named_tasks = ["tokio/tracing"]
serde_json = { workspace = true }
tempfile = { workspace = true }
proptest = { workspace = true }
tokio = { workspace = true, features = ["test-util"] }
362 changes: 362 additions & 0 deletions quickwit/quickwit-common/src/tower/circuit_breaker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,362 @@
// Copyright (C) 2024 Quickwit, Inc.
//
// Quickwit is offered under the AGPL v3.0 and as commercial software.
// For commercial licensing, contact us at [email protected].
//
// AGPL:
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;

use pin_project::pin_project;
use tokio::time::Instant;
use tower::{Layer, Service};

/// The circuit breaker layer implements the [circuit breaker pattern](https://martinfowler.com/bliki/CircuitBreaker.html).
///
/// It counts the errors emitted by the inner service, and if the number of errors exceeds a certain
/// threshold within a certain time window, it will "open" the circuit.
///
/// Requests will then be rejected for a given timeout.
/// After this timeout, the circuit breaker ends up in a HalfOpen state. It will allow a single
/// request to pass through. Depending on the result of this request, the circuit breaker will
/// either close the circuit again or open it again.
///
/// Implementation detail:
///
/// A circuit breaker needs to have some logic to estimate the chances for the next request
/// to fail. In this implementation, we use a simple heuristic that does not take in account
/// successes. We simply count the number or errors which happened in the last window.
///
/// The circuit breaker does not attempt to measure accurately the error rate.
/// Instead, it counts errors, and check for the time window in which these errors occurred.
/// This approach is accurate enough, robust, very easy to code and avoids calling the
/// `Instant::now()` at every error in the open state.
#[derive(Debug, Clone, Copy)]
pub struct CircuitBreakerLayer<Evaluator> {
pub max_error_count_per_time_window: u32,
pub time_window: Duration,
pub timeout: Duration,
pub evaluator: Evaluator,
}

pub trait CircuitBreakerEvaluator: Clone {
type Response;
type Error;
fn is_circuit_breaker_error(&self, output: &Result<Self::Response, Self::Error>) -> bool;
fn make_circuit_breaker_output(&self) -> Self::Error;
fn make_layer(
self,
max_num_errors_per_secs: u32,
timeout: Duration,
) -> CircuitBreakerLayer<Self> {
CircuitBreakerLayer {
max_error_count_per_time_window: max_num_errors_per_secs,
time_window: Duration::from_secs(1),
timeout,
evaluator: self,
}
}
}

impl<S, Evaluator: CircuitBreakerEvaluator> Layer<S> for CircuitBreakerLayer<Evaluator> {
type Service = CircuitBreaker<S, Evaluator>;

fn layer(&self, service: S) -> CircuitBreaker<S, Evaluator> {
let time_window = Duration::from_millis(self.time_window.as_millis() as u64);
let timeout = Duration::from_millis(self.timeout.as_millis() as u64);
CircuitBreaker {
underlying: service,
circuit_breaker_inner: Arc::new(Mutex::new(CircuitBreakerInner {
max_error_count_per_time_window: self.max_error_count_per_time_window,
time_window,
timeout,
state: CircuitBreakerState::Closed(ClosedState {
error_counter: 0u32,
error_window_end: Instant::now() + time_window,
}),
evaluator: self.evaluator.clone(),
})),
}
}
}

struct CircuitBreakerInner<Evaluator> {
max_error_count_per_time_window: u32,
time_window: Duration,
timeout: Duration,
evaluator: Evaluator,
state: CircuitBreakerState,
}

impl<Evaluator> CircuitBreakerInner<Evaluator> {
fn get_state(&mut self) -> CircuitBreakerState {
let new_state = match self.state {
CircuitBreakerState::Open { until } => {
let now = Instant::now();
if now < until {
CircuitBreakerState::Open { until }
} else {
CircuitBreakerState::HalfOpen
}
}
other => other,
};
self.state = new_state;
new_state
}

fn receive_error(&mut self) {
match self.state {
CircuitBreakerState::HalfOpen => {
self.state = CircuitBreakerState::Open {
until: Instant::now() + self.timeout,
}
}
CircuitBreakerState::Open { .. } => {}
CircuitBreakerState::Closed(ClosedState {
error_counter,
error_window_end,
}) => {
if error_counter < self.max_error_count_per_time_window {
self.state = CircuitBreakerState::Closed(ClosedState {
error_counter: error_counter + 1,
error_window_end,
});
return;
}
let now = Instant::now();
if now < error_window_end {
self.state = CircuitBreakerState::Open {
until: now + self.timeout,
};
} else {
self.state = CircuitBreakerState::Closed(ClosedState {
error_counter: 0u32,
error_window_end: now + self.time_window,
});
}
}
}
}

fn receive_success(&mut self) {
match self.state {
CircuitBreakerState::HalfOpen | CircuitBreakerState::Open { .. } => {
self.state = CircuitBreakerState::Closed(ClosedState {
error_counter: 0u32,
error_window_end: Instant::now() + self.time_window,
});
}
CircuitBreakerState::Closed { .. } => {
// We could actually take that as a signal.
}
}
}
}

#[derive(Clone)]
pub struct CircuitBreaker<S, Evaluator> {
underlying: S,
circuit_breaker_inner: Arc<Mutex<CircuitBreakerInner<Evaluator>>>,
}

impl<S, Evaluator> std::fmt::Debug for CircuitBreaker<S, Evaluator> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("CircuitBreaker").finish()
}
}

#[derive(Debug, Clone, Copy)]
enum CircuitBreakerState {
Open { until: Instant },
HalfOpen,
Closed(ClosedState),
}

#[derive(Debug, Clone, Copy)]
struct ClosedState {
error_counter: u32,
error_window_end: Instant,
}

impl<S, R, Evaluator> Service<R> for CircuitBreaker<S, Evaluator>
where
S: Service<R>,
Evaluator: CircuitBreakerEvaluator<Response = S::Response, Error = S::Error>,
{
type Response = S::Response;
type Error = S::Error;
type Future = CircuitBreakerFuture<S::Future, Evaluator>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut inner = self.circuit_breaker_inner.lock().unwrap();
let state = inner.get_state();
match state {
CircuitBreakerState::Closed { .. } | CircuitBreakerState::HalfOpen => {
self.underlying.poll_ready(cx)
}
CircuitBreakerState::Open { .. } => {
let circuit_break_error = inner.evaluator.make_circuit_breaker_output();
Poll::Ready(Err(circuit_break_error))
}
}
}

fn call(&mut self, request: R) -> Self::Future {
CircuitBreakerFuture {
underlying_fut: self.underlying.call(request),
circuit_breaker_inner: self.circuit_breaker_inner.clone(),
}
}
}

#[pin_project]
pub struct CircuitBreakerFuture<F, Evaluator> {
#[pin]
underlying_fut: F,
circuit_breaker_inner: Arc<Mutex<CircuitBreakerInner<Evaluator>>>,
}

impl<Response, Error, F, Evaluator> Future for CircuitBreakerFuture<F, Evaluator>
where
F: Future<Output = Result<Response, Error>>,
Evaluator: CircuitBreakerEvaluator<Response = Response, Error = Error>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let circuit_breaker_inner = self.circuit_breaker_inner.clone();
let poll_res = self.project().underlying_fut.poll(cx);
match poll_res {
Poll::Pending => Poll::Pending,
Poll::Ready(result) => {
let mut circuit_breaker_inner_lock = circuit_breaker_inner.lock().unwrap();
let is_circuit_breaker_error = circuit_breaker_inner_lock
.evaluator
.is_circuit_breaker_error(&result);
if is_circuit_breaker_error {
circuit_breaker_inner_lock.receive_error();
} else {
circuit_breaker_inner_lock.receive_success();
}
Poll::Ready(result)
}
}
}
}

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, Ordering};

use tower::{ServiceBuilder, ServiceExt};

use super::*;

#[derive(Debug)]
enum TestError {
CircuitBreak,
ServiceError,
}

#[derive(Debug, Clone, Copy)]
struct TestCircuitBreakerEvaluator;

impl CircuitBreakerEvaluator for TestCircuitBreakerEvaluator {
type Response = ();
type Error = TestError;

fn is_circuit_breaker_error(&self, output: &Result<Self::Response, Self::Error>) -> bool {
output.is_err()
}

fn make_circuit_breaker_output(&self) -> TestError {
TestError::CircuitBreak
}
}

#[tokio::test]
async fn test_circuit_breaker() {
tokio::time::pause();
let test_switch: Arc<AtomicBool> = Arc::new(AtomicBool::new(true));

const TIMEOUT: Duration = Duration::from_millis(500);

let mut service = ServiceBuilder::new()
.layer(TestCircuitBreakerEvaluator.make_layer(10, TIMEOUT))
.service_fn(|_| async {
if test_switch.load(Ordering::Relaxed) {
Ok(())
} else {
Err(TestError::ServiceError)
}
});

service.ready().await.unwrap().call(()).await.unwrap();

for _ in 0..1_000 {
service.ready().await.unwrap().call(()).await.unwrap();
}

test_switch.store(false, Ordering::Relaxed);

let mut service_error_count = 0;
let mut circuit_break_count = 0;
for _ in 0..1_000 {
match service.ready().await {
Ok(service) => {
service.call(()).await.unwrap_err();
service_error_count += 1;
}
Err(_circuit_breaker_error) => {
circuit_break_count += 1;
}
}
}

assert_eq!(service_error_count + circuit_break_count, 1_000);
assert_eq!(service_error_count, 11);

tokio::time::advance(TIMEOUT).await;

// The test request at half open fails.
for _ in 0..1_000 {
match service.ready().await {
Ok(service) => {
service.call(()).await.unwrap_err();
service_error_count += 1;
}
Err(_circuit_breaker_error) => {
circuit_break_count += 1;
}
}
}

assert_eq!(service_error_count + circuit_break_count, 2_000);
assert_eq!(service_error_count, 12);

test_switch.store(true, Ordering::Relaxed);
tokio::time::advance(TIMEOUT).await;

// The test request at half open succeeds.
for _ in 0..1_000 {
service.ready().await.unwrap().call(()).await.unwrap();
}
}
}
2 changes: 2 additions & 0 deletions quickwit/quickwit-common/src/tower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod box_layer;
mod box_service;
mod buffer;
mod change;
mod circuit_breaker;
mod delay;
mod estimate_rate;
mod event_listener;
Expand All @@ -41,6 +42,7 @@ pub use box_layer::BoxLayer;
pub use box_service::BoxService;
pub use buffer::{Buffer, BufferError, BufferLayer};
pub use change::Change;
pub use circuit_breaker::{CircuitBreaker, CircuitBreakerEvaluator, CircuitBreakerLayer};
pub use delay::{Delay, DelayLayer};
pub use estimate_rate::{EstimateRate, EstimateRateLayer};
pub use event_listener::{EventListener, EventListenerLayer};
Expand Down
Loading

0 comments on commit e5a20a6

Please sign in to comment.