diff --git a/quickwit/quickwit-common/Cargo.toml b/quickwit/quickwit-common/Cargo.toml
index 675522c7d9f..83170a8ec56 100644
--- a/quickwit/quickwit-common/Cargo.toml
+++ b/quickwit/quickwit-common/Cargo.toml
@@ -37,7 +37,7 @@ siphasher = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
-tokio-metrics ={ workspace = true }
+tokio-metrics = { workspace = true }
tokio-stream = { workspace = true }
tonic = { workspace = true }
tower = { workspace = true }
@@ -51,3 +51,4 @@ named_tasks = ["tokio/tracing"]
serde_json = { workspace = true }
tempfile = { workspace = true }
proptest = { workspace = true }
+tokio = { workspace = true, features = ["test-util"] }
diff --git a/quickwit/quickwit-common/src/tower/circuit_breaker.rs b/quickwit/quickwit-common/src/tower/circuit_breaker.rs
new file mode 100644
index 00000000000..ae80516ae83
--- /dev/null
+++ b/quickwit/quickwit-common/src/tower/circuit_breaker.rs
@@ -0,0 +1,372 @@
+// Copyright (C) 2024 Quickwit, Inc.
+//
+// Quickwit is offered under the AGPL v3.0 and as commercial software.
+// For commercial licensing, contact us at hello@quickwit.io.
+//
+// AGPL:
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll};
+use std::time::Duration;
+
+use pin_project::pin_project;
+use prometheus::IntCounter;
+use tokio::time::Instant;
+use tower::{Layer, Service};
+
+/// The circuit breaker layer implements the [circuit breaker pattern](https://martinfowler.com/bliki/CircuitBreaker.html).
+///
+/// It counts the errors emitted by the inner service, and if the number of errors exceeds a certain
+/// threshold within a certain time window, it will "open" the circuit.
+///
+/// Requests will then be rejected for a given timeout.
+/// After this timeout, the circuit breaker ends up in a HalfOpen state. It will allow a single
+/// request to pass through. Depending on the result of this request, the circuit breaker will
+/// either close the circuit again or open it again.
+///
+/// Implementation detail:
+///
+/// A circuit breaker needs to have some logic to estimate the chances for the next request
+/// to fail. In this implementation, we use a simple heuristic that does not take into account
+/// successes. We simply count the number of errors which happened in the last window.
+///
+/// The circuit breaker does not attempt to measure accurately the error rate.
+/// Instead, it counts errors, and checks the time window in which these errors occurred.
+/// This approach is accurate enough, robust, very easy to code and avoids calling
+/// `Instant::now()` at every error in the closed state.
+#[derive(Debug, Clone)]
+pub struct CircuitBreakerLayer {
+ max_error_count_per_time_window: u32, // error count at which the time window gets checked
+ time_window: Duration, // window within which that many errors open the circuit
+ timeout: Duration, // how long the circuit stays open before becoming half-open
+ evaluator: Evaluator, // cloned into every service built by `layer`
+ circuit_break_total: prometheus::IntCounter, // incremented each time the circuit opens
+}
+
+pub trait CircuitBreakerEvaluator: Clone { // classifies outcomes and builds the rejection error
+ type Response;
+ type Error;
+ fn is_circuit_breaker_error(&self, output: &Result) -> bool; // presumably true when `output` counts as a failure; confirm at the call site
+ fn make_circuit_breaker_output(&self) -> Self::Error; // error handed back by `poll_ready` while the circuit is open
+ fn make_layer(
+ self,
+ max_num_errors_per_secs: u32,
+ timeout: Duration,
+ circuit_break_total: prometheus::IntCounter,
+ ) -> CircuitBreakerLayer { // convenience constructor; consumes the evaluator
+ CircuitBreakerLayer {
+ max_error_count_per_time_window: max_num_errors_per_secs,
+ time_window: Duration::from_secs(1), // the window is fixed to one second, hence "per_secs"
+ timeout,
+ evaluator: self,
+ circuit_break_total,
+ }
+ }
+}
+
+impl Layer for CircuitBreakerLayer { // wraps a service into a `CircuitBreaker`
+ type Service = CircuitBreaker;
+
+ fn layer(&self, service: S) -> CircuitBreaker { // every call allocates its own breaker state
+ let time_window = Duration::from_millis(self.time_window.as_millis() as u64); // drops sub-millisecond precision
+ let timeout = Duration::from_millis(self.timeout.as_millis() as u64); // NOTE(review): `as u64` silently truncates an oversized u128
+ CircuitBreaker {
+ underlying: service,
+ circuit_breaker_inner: Arc::new(Mutex::new(CircuitBreakerInner {
+ max_error_count_per_time_window: self.max_error_count_per_time_window,
+ time_window,
+ timeout,
+ state: CircuitBreakerState::Closed(ClosedState { // the circuit starts closed
+ error_counter: 0u32,
+ error_window_end: Instant::now() + time_window, // first window starts at construction time
+ }),
+ evaluator: self.evaluator.clone(),
+ circuit_break_total: self.circuit_break_total.clone(),
+ })),
+ }
+ }
+}
+
+struct CircuitBreakerInner { // mutable breaker state; `layer` places it behind `Arc` + `Mutex`
+ max_error_count_per_time_window: u32, // error count at which the window gets checked
+ time_window: Duration, // length of one error-counting window
+ timeout: Duration, // duration of the open state
+ evaluator: Evaluator, // builds the error returned while the circuit is open
+ state: CircuitBreakerState, // current position in the state machine
+ circuit_break_total: IntCounter, // incremented on every transition to open
+}
+
+impl CircuitBreakerInner { // state-machine transitions of the circuit breaker
+ fn get_state(&mut self) -> CircuitBreakerState { // returns the state after applying the open -> half-open expiry
+ let new_state = match self.state {
+ CircuitBreakerState::Open { until } => {
+ let now = Instant::now();
+ if now < until {
+ CircuitBreakerState::Open { until }
+ } else {
+ CircuitBreakerState::HalfOpen // the open period is over
+ }
+ }
+ other => other, // closed and half-open are returned unchanged
+ };
+ self.state = new_state; // persist the possible transition
+ new_state
+ }
+
+ fn receive_error(&mut self) { // records one failed request
+ match self.state {
+ CircuitBreakerState::HalfOpen => { // a failure while half-open re-opens the circuit
+ self.circuit_break_total.inc();
+ self.state = CircuitBreakerState::Open {
+ until: Instant::now() + self.timeout,
+ }
+ }
+ CircuitBreakerState::Open { .. } => {} // already open: nothing to record
+ CircuitBreakerState::Closed(ClosedState {
+ error_counter,
+ error_window_end,
+ }) => {
+ if error_counter < self.max_error_count_per_time_window { // below the limit: count only, no clock read
+ self.state = CircuitBreakerState::Closed(ClosedState {
+ error_counter: error_counter + 1,
+ error_window_end,
+ });
+ return;
+ }
+ let now = Instant::now(); // limit reached: now check the window
+ if now < error_window_end { // the errors all fell inside the window: open
+ self.circuit_break_total.inc();
+ self.state = CircuitBreakerState::Open {
+ until: now + self.timeout,
+ };
+ } else { // the window has expired: start a new one
+ self.state = CircuitBreakerState::Closed(ClosedState {
+ error_counter: 0u32, // the error that triggered the reset is not counted
+ error_window_end: now + self.time_window,
+ });
+ }
+ }
+ }
+ }
+
+ fn receive_success(&mut self) { // records one successful request
+ match self.state {
+ CircuitBreakerState::HalfOpen | CircuitBreakerState::Open { .. } => { // a success closes the circuit, even if it is still open
+ self.state = CircuitBreakerState::Closed(ClosedState {
+ error_counter: 0u32,
+ error_window_end: Instant::now() + self.time_window,
+ });
+ }
+ CircuitBreakerState::Closed { .. } => {
+ // We could actually take that as a signal.
+ }
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct CircuitBreaker { // service wrapper that applies the circuit breaker
+ underlying: S, // the wrapped service
+ circuit_breaker_inner: Arc>>, // breaker state; `Arc`, so clones and response futures share it
+}
+
+impl std::fmt::Debug for CircuitBreaker { // manual impl that adds no fields to the output
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ f.debug_struct("CircuitBreaker").finish() // only the struct name is written
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+enum CircuitBreakerState { // the three states of the breaker
+ Open { until: Instant }, // `poll_ready` returns an error until `until` is reached
+ HalfOpen, // open period elapsed; the next recorded outcome closes or re-opens the circuit
+ Closed(ClosedState), // requests are forwarded and errors are counted
+}
+
+#[derive(Debug, Clone, Copy)]
+struct ClosedState { // bookkeeping kept while the circuit is closed
+ error_counter: u32, // errors counted since the current window started
+ error_window_end: Instant, // instant at which the current window ends
+}
+
+impl Service for CircuitBreaker // rejects in `poll_ready` while the circuit is open
+where
+ S: Service,
+ Evaluator: CircuitBreakerEvaluator,
+{
+ type Response = S::Response;
+ type Error = S::Error;
+ type Future = CircuitBreakerFuture;
+
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { // the open/closed decision is taken here
+ let mut inner = self.circuit_breaker_inner.lock().unwrap(); // panics if the mutex is poisoned
+ let state = inner.get_state(); // may promote open to half-open
+ match state {
+ CircuitBreakerState::Closed { .. } | CircuitBreakerState::HalfOpen => {
+ self.underlying.poll_ready(cx) // the lock guard is still held during this call
+ }
+ CircuitBreakerState::Open { .. } => {
+ let circuit_break_error = inner.evaluator.make_circuit_breaker_output();
+ Poll::Ready(Err(circuit_break_error)) // rejected without polling the inner service
+ }
+ }
+ }
+
+ fn call(&mut self, request: R) -> Self::Future { // the state is not re-checked here
+ CircuitBreakerFuture {
+ underlying_fut: self.underlying.call(request),
+ circuit_breaker_inner: self.circuit_breaker_inner.clone(), // the future shares the breaker state
+ }
+ }
+}
+
+#[pin_project]
+pub struct CircuitBreakerFuture { // future returned by `CircuitBreaker::call`
+ #[pin]
+ underlying_fut: F, // the inner service's future
+ circuit_breaker_inner: Arc>>, // handle to the breaker state shared with the service
+}
+
+impl Future for CircuitBreakerFuture
+where
+ F: Future