Skip to content

Commit aaf5608

Browse files
committed
Add options to customise retry policy
1 parent f97efce commit aaf5608

File tree

4 files changed

+105
-45
lines changed

4 files changed

+105
-45
lines changed

src/client/builder.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2024 Ben Ashford
2+
* Copyright 2020-2025 Ben Ashford
33
*
44
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
55
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -11,7 +11,7 @@
1111
use std::sync::Arc;
1212
use std::time::Duration;
1313

14-
use crate::error;
14+
use crate::{error, reconnect::ReconnectOptions};
1515

1616
#[derive(Debug)]
1717
/// Connection builder
@@ -20,6 +20,7 @@ pub struct ConnectionBuilder {
2020
pub(crate) port: u16,
2121
pub(crate) username: Option<Arc<str>>,
2222
pub(crate) password: Option<Arc<str>>,
23+
pub(crate) reconnect_options: ReconnectOptions,
2324
#[cfg(feature = "tls")]
2425
pub(crate) tls: bool,
2526
pub(crate) socket_keepalive: Option<Duration>,
@@ -36,6 +37,7 @@ impl ConnectionBuilder {
3637
port,
3738
username: None,
3839
password: None,
40+
reconnect_options: ReconnectOptions::default(),
3941
#[cfg(feature = "tls")]
4042
tls: false,
4143
socket_keepalive: Some(DEFAULT_KEEPALIVE),
@@ -72,4 +74,16 @@ impl ConnectionBuilder {
7274
self.socket_timeout = duration;
7375
self
7476
}
77+
78+
/// Set the reconnect timeout
79+
pub fn reconnect_timeout(&mut self, duration: Duration) -> &mut Self {
80+
self.reconnect_options.connection_timeout = duration;
81+
self
82+
}
83+
84+
/// Set the number of reconnection attempts
85+
pub fn reconnect_attempts(&mut self, attempts: u64) -> &mut Self {
86+
self.reconnect_options.max_connection_attempts = attempts;
87+
self
88+
}
7589
}

src/client/paired.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2017-2024 Ben Ashford
2+
* Copyright 2017-2025 Ben Ashford
33
*
44
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
55
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -264,7 +264,7 @@ impl ConnectionBuilder {
264264
Box::pin(con_f) as Pin<Box<dyn Future<Output = Result<_, error::Error>> + Send + Sync>>
265265
};
266266

267-
let reconnecting_con = reconnect(work_fn, conn_fn);
267+
let reconnecting_con = reconnect(work_fn, conn_fn, self.reconnect_options);
268268
reconnecting_con.map_ok(|con| PairedConnection {
269269
out_tx_c: Arc::new(con),
270270
})

src/client/pubsub/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2017-2024 Ben Ashford
2+
* Copyright 2017-2025 Ben Ashford
33
*
44
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
55
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -118,6 +118,7 @@ impl ConnectionBuilder {
118118
);
119119
Box::pin(con_f)
120120
},
121+
self.reconnect_options,
121122
);
122123
reconnecting_f.map_ok(|con| PubsubConnection {
123124
out_tx_c: Arc::new(con),

src/reconnect.rs

Lines changed: 85 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2018-2020 Ben Ashford
2+
* Copyright 2018-2025 Ben Ashford
33
*
44
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
55
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -28,10 +28,31 @@ type WorkFn<T, A> = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync;
2828
type ConnFn<T> =
2929
dyn Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>> + Send + Sync;
3030

31+
const CONNECTION_TIMEOUT_SECONDS: u64 = 1;
32+
const MAX_CONNECTION_ATTEMPTS: u64 = 10;
33+
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS);
34+
35+
#[derive(Debug, Copy, Clone)]
36+
pub(crate) struct ReconnectOptions {
37+
pub(crate) connection_timeout: Duration,
38+
pub(crate) max_connection_attempts: u64,
39+
}
40+
41+
impl Default for ReconnectOptions {
42+
#[inline]
43+
fn default() -> Self {
44+
ReconnectOptions {
45+
connection_timeout: CONNECTION_TIMEOUT,
46+
max_connection_attempts: MAX_CONNECTION_ATTEMPTS,
47+
}
48+
}
49+
}
50+
3151
struct ReconnectInner<A, T> {
3252
state: Mutex<ReconnectState<T>>,
3353
work_fn: Box<WorkFn<T, A>>,
3454
conn_fn: Box<ConnFn<T>>,
55+
reconnect_options: ReconnectOptions,
3556
}
3657

3758
impl<A, T> fmt::Debug for ReconnectInner<A, T> {
@@ -62,7 +83,11 @@ impl<A, T> Clone for Reconnect<A, T> {
6283
}
6384
}
6485

65-
pub(crate) async fn reconnect<A, T, W, C>(w: W, c: C) -> Result<Reconnect<A, T>, error::Error>
86+
pub(crate) async fn reconnect<A, T, W, C>(
87+
w: W,
88+
c: C,
89+
options: ReconnectOptions,
90+
) -> Result<Reconnect<A, T>, error::Error>
6691
where
6792
A: Send + 'static,
6893
W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static,
@@ -77,6 +102,8 @@ where
77102

78103
work_fn: Box::new(w),
79104
conn_fn: Box::new(c),
105+
106+
reconnect_options: options,
80107
}));
81108
let rf = {
82109
let state = r.0.state.lock().expect("Poisoned lock");
@@ -97,19 +124,14 @@ impl<T> fmt::Debug for ReconnectState<T> {
97124
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98125
write!(f, "ReconnectState::")?;
99126
match self {
100-
NotConnected => write!(f, "NotConnected"),
101-
Connected(_) => write!(f, "Connected"),
102-
ConnectionFailed(_) => write!(f, "ConnectionFailed"),
103-
Connecting => write!(f, "Connecting"),
127+
Self::NotConnected => write!(f, "NotConnected"),
128+
Self::Connected(_) => write!(f, "Connected"),
129+
Self::ConnectionFailed(_) => write!(f, "ConnectionFailed"),
130+
Self::Connecting => write!(f, "Connecting"),
104131
}
105132
}
106133
}
107134

108-
use self::ReconnectState::*;
109-
110-
const CONNECTION_TIMEOUT_SECONDS: u64 = 10;
111-
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS);
112-
113135
impl<A, T> Reconnect<A, T>
114136
where
115137
A: Send + 'static,
@@ -133,31 +155,33 @@ where
133155
pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> {
134156
let mut state = self.0.state.lock().expect("Cannot obtain read lock");
135157
match *state {
136-
NotConnected => {
158+
ReconnectState::NotConnected => {
137159
self.reconnect_spawn(state);
138160
Err(error::Error::Connection(ConnectionReason::NotConnected))
139161
}
140-
Connected(ref t) => {
162+
ReconnectState::Connected(ref t) => {
141163
let success = self.call_work(t, a)?;
142164
if !success {
143-
*state = NotConnected;
165+
*state = ReconnectState::NotConnected;
144166
self.reconnect_spawn(state);
145167
}
146168
Ok(())
147169
}
148-
ConnectionFailed(ref e) => {
170+
ReconnectState::ConnectionFailed(ref e) => {
149171
let mut lock = e.lock().expect("Poisioned lock");
150172
let e = match lock.take() {
151173
Some(e) => e,
152174
None => error::Error::Connection(ConnectionReason::NotConnected),
153175
};
154176
mem::drop(lock);
155177

156-
*state = NotConnected;
178+
*state = ReconnectState::NotConnected;
157179
self.reconnect_spawn(state);
158180
Err(e)
159181
}
160-
Connecting => Err(error::Error::Connection(ConnectionReason::Connecting)),
182+
ReconnectState::Connecting => {
183+
Err(error::Error::Connection(ConnectionReason::Connecting))
184+
}
161185
}
162186
}
163187

@@ -170,17 +194,17 @@ where
170194
log::info!("Attempting to reconnect, current state: {:?}", *state);
171195

172196
match *state {
173-
Connected(_) => {
197+
ReconnectState::Connected(_) => {
174198
return Either::Right(future::err(error::Error::Connection(
175199
ConnectionReason::Connected,
176200
)));
177201
}
178-
Connecting => {
202+
ReconnectState::Connecting => {
179203
return Either::Right(future::err(error::Error::Connection(
180204
ConnectionReason::Connecting,
181205
)));
182206
}
183-
NotConnected | ConnectionFailed(_) => (),
207+
ReconnectState::NotConnected | ReconnectState::ConnectionFailed(_) => (),
184208
}
185209
*state = ReconnectState::Connecting;
186210

@@ -189,33 +213,54 @@ where
189213
let reconnect = self.clone();
190214

191215
let connection_f = async move {
192-
let connection = match timeout(CONNECTION_TIMEOUT, (reconnect.0.conn_fn)()).await {
193-
Ok(con_r) => con_r,
194-
Err(_) => Err(error::internal(format!(
195-
"Connection timed-out after {} seconds",
196-
CONNECTION_TIMEOUT_SECONDS
197-
))),
198-
};
216+
let mut connection_result = Err(error::internal("Initial connection failed"));
217+
for i in 0..reconnect.0.reconnect_options.max_connection_attempts {
218+
log::debug!(
219+
"Connection attempt {}/{}",
220+
i + 1,
221+
reconnect.0.reconnect_options.max_connection_attempts
222+
);
223+
connection_result = match timeout(
224+
reconnect.0.reconnect_options.connection_timeout,
225+
(reconnect.0.conn_fn)(),
226+
)
227+
.await
228+
{
229+
Ok(con_r) => con_r,
230+
Err(_) => Err(error::internal(format!(
231+
"Connection timed-out after {} seconds",
232+
reconnect.0.reconnect_options.connection_timeout.as_secs()
233+
* reconnect.0.reconnect_options.max_connection_attempts
234+
))),
235+
};
236+
if connection_result.is_ok() {
237+
break;
238+
}
239+
}
199240

200241
let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock");
201242

202243
match *state {
203-
NotConnected | Connecting => match connection {
204-
Ok(t) => {
205-
log::info!("Connection established");
206-
*state = Connected(t);
207-
Ok(())
208-
}
209-
Err(e) => {
210-
log::error!("Connection cannot be established: {}", e);
211-
*state = ConnectionFailed(Mutex::new(Some(e)));
212-
Err(error::Error::Connection(ConnectionReason::ConnectionFailed))
244+
ReconnectState::NotConnected | ReconnectState::Connecting => {
245+
match connection_result {
246+
Ok(t) => {
247+
log::info!("Connection established");
248+
*state = ReconnectState::Connected(t);
249+
Ok(())
250+
}
251+
Err(e) => {
252+
log::error!("Connection cannot be established: {}", e);
253+
*state = ReconnectState::ConnectionFailed(Mutex::new(Some(e)));
254+
Err(error::Error::Connection(ConnectionReason::ConnectionFailed))
255+
}
213256
}
214-
},
215-
ConnectionFailed(_) => {
257+
}
258+
ReconnectState::ConnectionFailed(_) => {
216259
panic!("The connection state wasn't reset before connecting")
217260
}
218-
Connected(_) => panic!("A connected state shouldn't be attempting to reconnect"),
261+
ReconnectState::Connected(_) => {
262+
panic!("A connected state shouldn't be attempting to reconnect")
263+
}
219264
}
220265
};
221266

0 commit comments

Comments
 (0)