Skip to main content

tower/circuit_breaker/
service.rs

1use std::{
2    sync::{Arc, Mutex},
3    task::{Context, Poll},
4};
5
6use tower_service::Service;
7
8use super::{future::ResponseFuture, policy::CircuitPolicy};
9
/// Lifecycle state reported by a [`CircuitBreaker`] service.
///
/// Read the current value with [`CircuitBreaker::status`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CircuitStatus {
    /// Healthy: requests are forwarded to the inner service as usual.
    Closed,
    /// Tripped: readiness checks fail with [`CircuitError::Open`] and the
    /// inner service is not polled.
    Open,
    /// Testing recovery: entered from `Open` once the policy allows a probe;
    /// requests are forwarded to the inner service again.
    HalfOpen,
}
20
/// Failure produced by a [`CircuitBreaker`] service.
///
/// Separates requests refused by the breaker itself from requests that
/// reached the wrapped service and failed there.
#[derive(Debug)]
pub enum CircuitError<E> {
    /// Refused because the circuit is open; the wrapped service was never
    /// invoked.
    Open,
    /// Error reported by the wrapped service, passed through unchanged.
    Inner(E),
}
29
30impl<E: std::fmt::Display> std::fmt::Display for CircuitError<E> {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Open => write!(f, "circuit breaker is open"),
34            Self::Inner(e) => write!(f, "{e}"),
35        }
36    }
37}
38
39impl<E: std::error::Error + 'static> std::error::Error for CircuitError<E> {
40    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
41        match self {
42            Self::Inner(e) => Some(e),
43            Self::Open => None,
44        }
45    }
46}
47
48/// Shared mutable state between a [`CircuitBreaker`] and its [`ResponseFuture`].
49pub(crate) struct SharedState<P> {
50    pub(crate) status: CircuitStatus,
51    pub(crate) policy: P,
52}
53
54/// Tower [`Service`] implementing the circuit-breaker pattern.
55///
56/// The open/probe/close criteria are driven by a [`CircuitPolicy`], making
57/// the triggering logic independently customisable.  The built-in policy is
58/// [`ConsecutiveFailures`]; supply any type implementing [`CircuitPolicy`]
59/// via [`CircuitBreaker::new`] or [`CircuitBreakerLayer::with_policy`] for
60/// custom strategies.
61///
62/// # Thread safety
63///
64/// `CircuitBreaker<S, P>` is [`Send`] when both `S` and `P` are [`Send`].
65/// This is enforced structurally: the policy is held behind
66/// `Arc<Mutex<P>>`, so `Arc<Mutex<P>>: Send` requires `P: Send`.
67/// No explicit bound is placed on `P` in the [`Service`] impl, so
68/// `!Send` policies can still be used in single-threaded contexts without
69/// a compile error.  `P: Sync` is never required.
70///
71/// See [`CircuitPolicy`] for more detail.
72///
73/// See the [module documentation](super) for a full example.
74///
75/// [`ConsecutiveFailures`]: super::ConsecutiveFailures
76/// [`CircuitBreakerLayer::with_policy`]: super::CircuitBreakerLayer::with_policy
77/// [`CircuitPolicy`]: super::CircuitPolicy
78#[derive(Clone)]
79pub struct CircuitBreaker<S, P> {
80    inner: S,
81    pub(crate) shared: Arc<Mutex<SharedState<P>>>,
82}
83
84impl<S, P: CircuitPolicy> CircuitBreaker<S, P> {
85    /// Wrap `inner` with the given [`CircuitPolicy`].
86    pub fn new(inner: S, policy: P) -> Self {
87        Self {
88            inner,
89            shared: Arc::new(Mutex::new(SharedState {
90                status: CircuitStatus::Closed,
91                policy,
92            })),
93        }
94    }
95
96    /// Return the current [`CircuitStatus`].
97    pub fn status(&self) -> CircuitStatus {
98        self.shared
99            .lock()
100            .expect("circuit breaker state poisoned")
101            .status
102            .clone()
103    }
104
105    /// Manually close the circuit (e.g. after operator confirmation that the
106    /// backend is healthy).
107    ///
108    /// Calls [`CircuitPolicy::on_half_open`] to reset any per-window counters
109    /// in the policy, then sets the status to [`Closed`][CircuitStatus::Closed].
110    pub fn reset(&self) {
111        let mut s = self.shared.lock().expect("circuit breaker state poisoned");
112        s.policy.on_half_open(); // reuse the window-clear hook
113        s.status = CircuitStatus::Closed;
114    }
115}
116
117impl<S, P, Request> Service<Request> for CircuitBreaker<S, P>
118where
119    S: Service<Request>,
120    P: CircuitPolicy,
121{
122    type Response = S::Response;
123    type Error = CircuitError<S::Error>;
124    type Future = ResponseFuture<S::Future, S::Response, S::Error, P>;
125
126    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127        {
128            let mut s = self.shared.lock().expect("circuit breaker state poisoned");
129            if s.status == CircuitStatus::Open {
130                if s.policy.should_probe() {
131                    s.policy.on_half_open();
132                    s.status = CircuitStatus::HalfOpen;
133                    // fall through to delegate to inner service
134                } else {
135                    return Poll::Ready(Err(CircuitError::Open));
136                }
137            }
138        }
139
140        self.inner.poll_ready(cx).map_err(CircuitError::Inner)
141    }
142
143    fn call(&mut self, req: Request) -> Self::Future {
144        ResponseFuture::new(self.shared.clone(), self.inner.call(req))
145    }
146}
147
148// ===== Tests =====
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuit_breaker::{CircuitBreakerLayer, ConsecutiveFailures};
    use std::time::Duration;
    use tower::{ServiceBuilder, ServiceExt};

    #[tokio::test]
    async fn closed_passes_requests_through() {
        let mut svc = ServiceBuilder::new()
            .layer(CircuitBreakerLayer::new(5, 0.8, Duration::from_secs(60)))
            .service_fn(|req: &'static str| async move { Ok::<_, &'static str>(req) });

        let resp = svc.ready().await.unwrap().call("hello").await;
        // Check the payload, not just success, so we know the inner response
        // is returned unchanged.
        assert!(matches!(resp, Ok("hello")));
    }

    #[tokio::test]
    async fn opens_after_failure_threshold() {
        let mut svc = ServiceBuilder::new()
            .layer(CircuitBreakerLayer::new(3, 0.8, Duration::from_secs(60)))
            .service_fn(|_: &'static str| async move { Err::<&str, _>("fail") });

        for _ in 0..3 {
            let _ = svc.ready().await.unwrap().call("req").await;
        }

        // Circuit is now Open — poll_ready should reject.
        let result = svc.ready().await;
        assert!(matches!(result, Err(CircuitError::Open)));
    }

    #[tokio::test]
    async fn manual_reset_closes_circuit() {
        let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("fail") });
        let policy = ConsecutiveFailures::new(2, 0.8, Duration::from_secs(60));
        let cb = CircuitBreaker::new(inner, policy);

        let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await;
        let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await;
        assert_eq!(cb.status(), CircuitStatus::Open);

        cb.reset();
        assert_eq!(cb.status(), CircuitStatus::Closed);
    }

    #[tokio::test]
    async fn custom_policy_is_accepted() {
        // Verify the Service impl compiles and runs with a hand-rolled policy.
        #[derive(Clone)]
        struct AlwaysOpen;
        impl CircuitPolicy for AlwaysOpen {
            fn on_success(&mut self) -> bool { false }
            fn on_failure(&mut self) -> bool { true }
            fn should_probe(&self) -> bool { false }
            fn on_half_open(&mut self) {}
        }

        let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("x") });
        let cb = CircuitBreaker::new(inner, AlwaysOpen);

        // One failure should open the circuit.
        let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await;
        assert_eq!(cb.status(), CircuitStatus::Open);
    }

    #[tokio::test]
    async fn open_circuit_probes_into_half_open() {
        // Opens on the first failure and always allows a probe afterwards,
        // which exercises the Open -> HalfOpen branch of `poll_ready`.
        #[derive(Clone)]
        struct ProbeImmediately;
        impl CircuitPolicy for ProbeImmediately {
            fn on_success(&mut self) -> bool { false }
            fn on_failure(&mut self) -> bool { true }
            fn should_probe(&self) -> bool { true }
            fn on_half_open(&mut self) {}
        }

        let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("x") });
        let mut cb = CircuitBreaker::new(inner, ProbeImmediately);

        let _ = cb.ready().await.unwrap().call("req").await;
        assert_eq!(cb.status(), CircuitStatus::Open);

        // The next readiness check should admit a probe instead of rejecting.
        assert!(cb.ready().await.is_ok());
        assert_eq!(cb.status(), CircuitStatus::HalfOpen);
    }
}