tower/circuit_breaker/
service.rs1use std::{
2 sync::{Arc, Mutex},
3 task::{Context, Poll},
4};
5
6use tower_service::Service;
7
8use super::{future::ResponseFuture, policy::CircuitPolicy};
9
/// Current state of a [`CircuitBreaker`].
///
/// This is a plain fieldless enum, so it is `Copy`: [`CircuitBreaker::status`]
/// hands out a snapshot by value, and callers can freely store or hash it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitStatus {
    /// Requests are forwarded to the inner service. Every breaker starts here.
    Closed,
    /// `poll_ready` rejects with [`CircuitError::Open`] without polling the
    /// inner service, unless the policy's `should_probe` allows a trial request.
    Open,
    /// Entered from `Open` when `should_probe` returned `true`; requests are
    /// again forwarded to the inner service.
    HalfOpen,
}
20
/// Error produced by a [`CircuitBreaker`] service.
///
/// The extra derives apply only when `E` supports them (e.g. `PartialEq` needs
/// `E: PartialEq`), so they never restrict which inner errors can be wrapped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitError<E> {
    /// The circuit is open; the request was rejected in `poll_ready` without
    /// the inner service being polled.
    Open,
    /// An error from the inner service. It comes from its `poll_ready` and
    /// presumably also from the response future — TODO confirm in `ResponseFuture`.
    Inner(E),
}
29
30impl<E: std::fmt::Display> std::fmt::Display for CircuitError<E> {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Open => write!(f, "circuit breaker is open"),
34 Self::Inner(e) => write!(f, "{e}"),
35 }
36 }
37}
38
39impl<E: std::error::Error + 'static> std::error::Error for CircuitError<E> {
40 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
41 match self {
42 Self::Inner(e) => Some(e),
43 Self::Open => None,
44 }
45 }
46}
47
48pub(crate) struct SharedState<P> {
50 pub(crate) status: CircuitStatus,
51 pub(crate) policy: P,
52}
53
54#[derive(Clone)]
79pub struct CircuitBreaker<S, P> {
80 inner: S,
81 pub(crate) shared: Arc<Mutex<SharedState<P>>>,
82}
83
84impl<S, P: CircuitPolicy> CircuitBreaker<S, P> {
85 pub fn new(inner: S, policy: P) -> Self {
87 Self {
88 inner,
89 shared: Arc::new(Mutex::new(SharedState {
90 status: CircuitStatus::Closed,
91 policy,
92 })),
93 }
94 }
95
96 pub fn status(&self) -> CircuitStatus {
98 self.shared
99 .lock()
100 .expect("circuit breaker state poisoned")
101 .status
102 .clone()
103 }
104
105 pub fn reset(&self) {
111 let mut s = self.shared.lock().expect("circuit breaker state poisoned");
112 s.policy.on_half_open(); s.status = CircuitStatus::Closed;
114 }
115}
116
117impl<S, P, Request> Service<Request> for CircuitBreaker<S, P>
118where
119 S: Service<Request>,
120 P: CircuitPolicy,
121{
122 type Response = S::Response;
123 type Error = CircuitError<S::Error>;
124 type Future = ResponseFuture<S::Future, S::Response, S::Error, P>;
125
126 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127 {
128 let mut s = self.shared.lock().expect("circuit breaker state poisoned");
129 if s.status == CircuitStatus::Open {
130 if s.policy.should_probe() {
131 s.policy.on_half_open();
132 s.status = CircuitStatus::HalfOpen;
133 } else {
135 return Poll::Ready(Err(CircuitError::Open));
136 }
137 }
138 }
139
140 self.inner.poll_ready(cx).map_err(CircuitError::Inner)
141 }
142
143 fn call(&mut self, req: Request) -> Self::Future {
144 ResponseFuture::new(self.shared.clone(), self.inner.call(req))
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuit_breaker::{CircuitBreakerLayer, ConsecutiveFailures};
    use std::time::Duration;
    use tower::{ServiceBuilder, ServiceExt};

    /// With a healthy inner service and a closed circuit, the response comes back `Ok`.
    #[tokio::test]
    async fn closed_passes_requests_through() {
        let mut svc = ServiceBuilder::new()
            .layer(CircuitBreakerLayer::new(5, 0.8, Duration::from_secs(60)))
            .service_fn(|req: &'static str| async move { Ok::<_, &'static str>(req) });

        let ready = svc.ready().await.unwrap();
        let outcome = ready.call("hello").await;
        assert!(outcome.is_ok());
    }

    /// After three failures (the layer's first argument is 3, presumably the
    /// failure threshold), readiness is refused with `CircuitError::Open`.
    #[tokio::test]
    async fn opens_after_failure_threshold() {
        let mut svc = ServiceBuilder::new()
            .layer(CircuitBreakerLayer::new(3, 0.8, Duration::from_secs(60)))
            .service_fn(|_: &'static str| async move { Err::<&str, _>("fail") });

        for _ in 0..3 {
            let _ = svc.ready().await.unwrap().call("req").await;
        }

        let verdict = svc.ready().await;
        assert!(matches!(verdict, Err(CircuitError::Open)));
    }

    /// `reset` moves an open circuit back to `Closed`.
    #[tokio::test]
    async fn manual_reset_closes_circuit() {
        let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("fail") });
        let cb = CircuitBreaker::new(
            inner,
            ConsecutiveFailures::new(2, 0.8, Duration::from_secs(60)),
        );

        // Each oneshot drives a fresh clone; all clones share one circuit.
        for _ in 0..2 {
            let _ = cb.clone().oneshot("req").await;
        }
        assert_eq!(cb.status(), CircuitStatus::Open);

        cb.reset();
        assert_eq!(cb.status(), CircuitStatus::Closed);
    }

    /// A user-defined `CircuitPolicy` can drive the breaker.
    #[tokio::test]
    async fn custom_policy_is_accepted() {
        // Derived so the test does not depend on whether `CircuitBreaker: Clone`
        // requires `P: Clone`.
        #[derive(Clone)]
        struct AlwaysOpen;

        impl CircuitPolicy for AlwaysOpen {
            fn on_success(&mut self) -> bool {
                false
            }

            fn on_failure(&mut self) -> bool {
                true
            }

            fn should_probe(&self) -> bool {
                false
            }

            fn on_half_open(&mut self) {}
        }

        let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("x") });
        let cb = CircuitBreaker::new(inner, AlwaysOpen);

        let _ = cb.clone().oneshot("req").await;
        assert_eq!(cb.status(), CircuitStatus::Open);
    }
}