-
Notifications
You must be signed in to change notification settings - Fork 414
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
89ce762
commit e5a20a6
Showing
4 changed files
with
401 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,362 @@ | ||
// Copyright (C) 2024 Quickwit, Inc. | ||
// | ||
// Quickwit is offered under the AGPL v3.0 and as commercial software. | ||
// For commercial licensing, contact us at [email protected]. | ||
// | ||
// AGPL: | ||
// This program is free software: you can redistribute it and/or modify | ||
// it under the terms of the GNU Affero General Public License as | ||
// published by the Free Software Foundation, either version 3 of the | ||
// License, or (at your option) any later version. | ||
// | ||
// This program is distributed in the hope that it will be useful, | ||
// but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
// GNU Affero General Public License for more details. | ||
// | ||
// You should have received a copy of the GNU Affero General Public License | ||
// along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
use std::future::Future; | ||
use std::pin::Pin; | ||
use std::sync::{Arc, Mutex}; | ||
use std::task::{Context, Poll}; | ||
use std::time::Duration; | ||
|
||
use pin_project::pin_project; | ||
use tokio::time::Instant; | ||
use tower::{Layer, Service}; | ||
|
||
/// The circuit breaker layer implements the [circuit breaker pattern](https://martinfowler.com/bliki/CircuitBreaker.html). | ||
/// | ||
/// It counts the errors emitted by the inner service, and if the number of errors exceeds a certain | ||
/// threshold within a certain time window, it will "open" the circuit. | ||
/// | ||
/// Requests will then be rejected for a given timeout. | ||
/// After this timeout, the circuit breaker ends up in a HalfOpen state. It will allow a single | ||
/// request to pass through. Depending on the result of this request, the circuit breaker will | ||
/// either close the circuit again or open it again. | ||
/// | ||
/// Implementation detail: | ||
/// | ||
/// A circuit breaker needs to have some logic to estimate the chances for the next request | ||
/// to fail. In this implementation, we use a simple heuristic that does not take in account | ||
/// successes. We simply count the number or errors which happened in the last window. | ||
/// | ||
/// The circuit breaker does not attempt to measure accurately the error rate. | ||
/// Instead, it counts errors, and check for the time window in which these errors occurred. | ||
/// This approach is accurate enough, robust, very easy to code and avoids calling the | ||
/// `Instant::now()` at every error in the open state. | ||
#[derive(Debug, Clone, Copy)] | ||
pub struct CircuitBreakerLayer<Evaluator> { | ||
pub max_error_count_per_time_window: u32, | ||
pub time_window: Duration, | ||
pub timeout: Duration, | ||
pub evaluator: Evaluator, | ||
} | ||
|
||
pub trait CircuitBreakerEvaluator: Clone { | ||
type Response; | ||
type Error; | ||
fn is_circuit_breaker_error(&self, output: &Result<Self::Response, Self::Error>) -> bool; | ||
fn make_circuit_breaker_output(&self) -> Self::Error; | ||
fn make_layer( | ||
self, | ||
max_num_errors_per_secs: u32, | ||
timeout: Duration, | ||
) -> CircuitBreakerLayer<Self> { | ||
CircuitBreakerLayer { | ||
max_error_count_per_time_window: max_num_errors_per_secs, | ||
time_window: Duration::from_secs(1), | ||
timeout, | ||
evaluator: self, | ||
} | ||
} | ||
} | ||
|
||
impl<S, Evaluator: CircuitBreakerEvaluator> Layer<S> for CircuitBreakerLayer<Evaluator> { | ||
type Service = CircuitBreaker<S, Evaluator>; | ||
|
||
fn layer(&self, service: S) -> CircuitBreaker<S, Evaluator> { | ||
let time_window = Duration::from_millis(self.time_window.as_millis() as u64); | ||
let timeout = Duration::from_millis(self.timeout.as_millis() as u64); | ||
CircuitBreaker { | ||
underlying: service, | ||
circuit_breaker_inner: Arc::new(Mutex::new(CircuitBreakerInner { | ||
max_error_count_per_time_window: self.max_error_count_per_time_window, | ||
time_window, | ||
timeout, | ||
state: CircuitBreakerState::Closed(ClosedState { | ||
error_counter: 0u32, | ||
error_window_end: Instant::now() + time_window, | ||
}), | ||
evaluator: self.evaluator.clone(), | ||
})), | ||
} | ||
} | ||
} | ||
|
||
struct CircuitBreakerInner<Evaluator> { | ||
max_error_count_per_time_window: u32, | ||
time_window: Duration, | ||
timeout: Duration, | ||
evaluator: Evaluator, | ||
state: CircuitBreakerState, | ||
} | ||
|
||
impl<Evaluator> CircuitBreakerInner<Evaluator> { | ||
fn get_state(&mut self) -> CircuitBreakerState { | ||
let new_state = match self.state { | ||
CircuitBreakerState::Open { until } => { | ||
let now = Instant::now(); | ||
if now < until { | ||
CircuitBreakerState::Open { until } | ||
} else { | ||
CircuitBreakerState::HalfOpen | ||
} | ||
} | ||
other => other, | ||
}; | ||
self.state = new_state; | ||
new_state | ||
} | ||
|
||
fn receive_error(&mut self) { | ||
match self.state { | ||
CircuitBreakerState::HalfOpen => { | ||
self.state = CircuitBreakerState::Open { | ||
until: Instant::now() + self.timeout, | ||
} | ||
} | ||
CircuitBreakerState::Open { .. } => {} | ||
CircuitBreakerState::Closed(ClosedState { | ||
error_counter, | ||
error_window_end, | ||
}) => { | ||
if error_counter < self.max_error_count_per_time_window { | ||
self.state = CircuitBreakerState::Closed(ClosedState { | ||
error_counter: error_counter + 1, | ||
error_window_end, | ||
}); | ||
return; | ||
} | ||
let now = Instant::now(); | ||
if now < error_window_end { | ||
self.state = CircuitBreakerState::Open { | ||
until: now + self.timeout, | ||
}; | ||
} else { | ||
self.state = CircuitBreakerState::Closed(ClosedState { | ||
error_counter: 0u32, | ||
error_window_end: now + self.time_window, | ||
}); | ||
} | ||
} | ||
} | ||
} | ||
|
||
fn receive_success(&mut self) { | ||
match self.state { | ||
CircuitBreakerState::HalfOpen | CircuitBreakerState::Open { .. } => { | ||
self.state = CircuitBreakerState::Closed(ClosedState { | ||
error_counter: 0u32, | ||
error_window_end: Instant::now() + self.time_window, | ||
}); | ||
} | ||
CircuitBreakerState::Closed { .. } => { | ||
// We could actually take that as a signal. | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[derive(Clone)] | ||
pub struct CircuitBreaker<S, Evaluator> { | ||
underlying: S, | ||
circuit_breaker_inner: Arc<Mutex<CircuitBreakerInner<Evaluator>>>, | ||
} | ||
|
||
impl<S, Evaluator> std::fmt::Debug for CircuitBreaker<S, Evaluator> { | ||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | ||
f.debug_struct("CircuitBreaker").finish() | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, Copy)] | ||
enum CircuitBreakerState { | ||
Open { until: Instant }, | ||
HalfOpen, | ||
Closed(ClosedState), | ||
} | ||
|
||
#[derive(Debug, Clone, Copy)] | ||
struct ClosedState { | ||
error_counter: u32, | ||
error_window_end: Instant, | ||
} | ||
|
||
impl<S, R, Evaluator> Service<R> for CircuitBreaker<S, Evaluator> | ||
where | ||
S: Service<R>, | ||
Evaluator: CircuitBreakerEvaluator<Response = S::Response, Error = S::Error>, | ||
{ | ||
type Response = S::Response; | ||
type Error = S::Error; | ||
type Future = CircuitBreakerFuture<S::Future, Evaluator>; | ||
|
||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { | ||
let mut inner = self.circuit_breaker_inner.lock().unwrap(); | ||
let state = inner.get_state(); | ||
match state { | ||
CircuitBreakerState::Closed { .. } | CircuitBreakerState::HalfOpen => { | ||
self.underlying.poll_ready(cx) | ||
} | ||
CircuitBreakerState::Open { .. } => { | ||
let circuit_break_error = inner.evaluator.make_circuit_breaker_output(); | ||
Poll::Ready(Err(circuit_break_error)) | ||
} | ||
} | ||
} | ||
|
||
fn call(&mut self, request: R) -> Self::Future { | ||
CircuitBreakerFuture { | ||
underlying_fut: self.underlying.call(request), | ||
circuit_breaker_inner: self.circuit_breaker_inner.clone(), | ||
} | ||
} | ||
} | ||
|
||
#[pin_project] | ||
pub struct CircuitBreakerFuture<F, Evaluator> { | ||
#[pin] | ||
underlying_fut: F, | ||
circuit_breaker_inner: Arc<Mutex<CircuitBreakerInner<Evaluator>>>, | ||
} | ||
|
||
impl<Response, Error, F, Evaluator> Future for CircuitBreakerFuture<F, Evaluator> | ||
where | ||
F: Future<Output = Result<Response, Error>>, | ||
Evaluator: CircuitBreakerEvaluator<Response = Response, Error = Error>, | ||
{ | ||
type Output = F::Output; | ||
|
||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
let circuit_breaker_inner = self.circuit_breaker_inner.clone(); | ||
let poll_res = self.project().underlying_fut.poll(cx); | ||
match poll_res { | ||
Poll::Pending => Poll::Pending, | ||
Poll::Ready(result) => { | ||
let mut circuit_breaker_inner_lock = circuit_breaker_inner.lock().unwrap(); | ||
let is_circuit_breaker_error = circuit_breaker_inner_lock | ||
.evaluator | ||
.is_circuit_breaker_error(&result); | ||
if is_circuit_breaker_error { | ||
circuit_breaker_inner_lock.receive_error(); | ||
} else { | ||
circuit_breaker_inner_lock.receive_success(); | ||
} | ||
Poll::Ready(result) | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use std::sync::atomic::{AtomicBool, Ordering}; | ||
|
||
use tower::{ServiceBuilder, ServiceExt}; | ||
|
||
use super::*; | ||
|
||
#[derive(Debug)] | ||
enum TestError { | ||
CircuitBreak, | ||
ServiceError, | ||
} | ||
|
||
#[derive(Debug, Clone, Copy)] | ||
struct TestCircuitBreakerEvaluator; | ||
|
||
impl CircuitBreakerEvaluator for TestCircuitBreakerEvaluator { | ||
type Response = (); | ||
type Error = TestError; | ||
|
||
fn is_circuit_breaker_error(&self, output: &Result<Self::Response, Self::Error>) -> bool { | ||
output.is_err() | ||
} | ||
|
||
fn make_circuit_breaker_output(&self) -> TestError { | ||
TestError::CircuitBreak | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_circuit_breaker() { | ||
tokio::time::pause(); | ||
let test_switch: Arc<AtomicBool> = Arc::new(AtomicBool::new(true)); | ||
|
||
const TIMEOUT: Duration = Duration::from_millis(500); | ||
|
||
let mut service = ServiceBuilder::new() | ||
.layer(TestCircuitBreakerEvaluator.make_layer(10, TIMEOUT)) | ||
.service_fn(|_| async { | ||
if test_switch.load(Ordering::Relaxed) { | ||
Ok(()) | ||
} else { | ||
Err(TestError::ServiceError) | ||
} | ||
}); | ||
|
||
service.ready().await.unwrap().call(()).await.unwrap(); | ||
|
||
for _ in 0..1_000 { | ||
service.ready().await.unwrap().call(()).await.unwrap(); | ||
} | ||
|
||
test_switch.store(false, Ordering::Relaxed); | ||
|
||
let mut service_error_count = 0; | ||
let mut circuit_break_count = 0; | ||
for _ in 0..1_000 { | ||
match service.ready().await { | ||
Ok(service) => { | ||
service.call(()).await.unwrap_err(); | ||
service_error_count += 1; | ||
} | ||
Err(_circuit_breaker_error) => { | ||
circuit_break_count += 1; | ||
} | ||
} | ||
} | ||
|
||
assert_eq!(service_error_count + circuit_break_count, 1_000); | ||
assert_eq!(service_error_count, 11); | ||
|
||
tokio::time::advance(TIMEOUT).await; | ||
|
||
// The test request at half open fails. | ||
for _ in 0..1_000 { | ||
match service.ready().await { | ||
Ok(service) => { | ||
service.call(()).await.unwrap_err(); | ||
service_error_count += 1; | ||
} | ||
Err(_circuit_breaker_error) => { | ||
circuit_break_count += 1; | ||
} | ||
} | ||
} | ||
|
||
assert_eq!(service_error_count + circuit_break_count, 2_000); | ||
assert_eq!(service_error_count, 12); | ||
|
||
test_switch.store(true, Ordering::Relaxed); | ||
tokio::time::advance(TIMEOUT).await; | ||
|
||
// The test request at half open succeeds. | ||
for _ in 0..1_000 { | ||
service.ready().await.unwrap().call(()).await.unwrap(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.