11/*
2- * Copyright 2018-2020 Ben Ashford
2+ * Copyright 2018-2025 Ben Ashford
33 *
44 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
55 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -28,10 +28,31 @@ type WorkFn<T, A> = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync;
2828type ConnFn < T > =
2929 dyn Fn ( ) -> Pin < Box < dyn Future < Output = Result < T , error:: Error > > + Send + Sync > > + Send + Sync ;
3030
31+ const CONNECTION_TIMEOUT_SECONDS : u64 = 1 ;
32+ const MAX_CONNECTION_ATTEMPTS : u64 = 10 ;
33+ const CONNECTION_TIMEOUT : Duration = Duration :: from_secs ( CONNECTION_TIMEOUT_SECONDS ) ;
34+
35+ #[ derive( Debug , Copy , Clone ) ]
36+ pub ( crate ) struct ReconnectOptions {
37+ pub ( crate ) connection_timeout : Duration ,
38+ pub ( crate ) max_connection_attempts : u64 ,
39+ }
40+
41+ impl Default for ReconnectOptions {
42+ #[ inline]
43+ fn default ( ) -> Self {
44+ ReconnectOptions {
45+ connection_timeout : CONNECTION_TIMEOUT ,
46+ max_connection_attempts : MAX_CONNECTION_ATTEMPTS ,
47+ }
48+ }
49+ }
50+
3151struct ReconnectInner < A , T > {
3252 state : Mutex < ReconnectState < T > > ,
3353 work_fn : Box < WorkFn < T , A > > ,
3454 conn_fn : Box < ConnFn < T > > ,
55+ reconnect_options : ReconnectOptions ,
3556}
3657
3758impl < A , T > fmt:: Debug for ReconnectInner < A , T > {
@@ -62,7 +83,11 @@ impl<A, T> Clone for Reconnect<A, T> {
6283 }
6384}
6485
65- pub ( crate ) async fn reconnect < A , T , W , C > ( w : W , c : C ) -> Result < Reconnect < A , T > , error:: Error >
86+ pub ( crate ) async fn reconnect < A , T , W , C > (
87+ w : W ,
88+ c : C ,
89+ options : ReconnectOptions ,
90+ ) -> Result < Reconnect < A , T > , error:: Error >
6691where
6792 A : Send + ' static ,
6893 W : Fn ( & T , A ) -> Result < ( ) , error:: Error > + Send + Sync + ' static ,
77102
78103 work_fn : Box :: new ( w) ,
79104 conn_fn : Box :: new ( c) ,
105+
106+ reconnect_options : options,
80107 } ) ) ;
81108 let rf = {
82109 let state = r. 0 . state . lock ( ) . expect ( "Poisoned lock" ) ;
@@ -97,19 +124,14 @@ impl<T> fmt::Debug for ReconnectState<T> {
97124 fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
98125 write ! ( f, "ReconnectState::" ) ?;
99126 match self {
100- NotConnected => write ! ( f, "NotConnected" ) ,
101- Connected ( _) => write ! ( f, "Connected" ) ,
102- ConnectionFailed ( _) => write ! ( f, "ConnectionFailed" ) ,
103- Connecting => write ! ( f, "Connecting" ) ,
127+ Self :: NotConnected => write ! ( f, "NotConnected" ) ,
128+ Self :: Connected ( _) => write ! ( f, "Connected" ) ,
129+ Self :: ConnectionFailed ( _) => write ! ( f, "ConnectionFailed" ) ,
130+ Self :: Connecting => write ! ( f, "Connecting" ) ,
104131 }
105132 }
106133}
107134
108- use self :: ReconnectState :: * ;
109-
110- const CONNECTION_TIMEOUT_SECONDS : u64 = 10 ;
111- const CONNECTION_TIMEOUT : Duration = Duration :: from_secs ( CONNECTION_TIMEOUT_SECONDS ) ;
112-
113135impl < A , T > Reconnect < A , T >
114136where
115137 A : Send + ' static ,
@@ -133,31 +155,33 @@ where
133155 pub ( crate ) fn do_work ( & self , a : A ) -> Result < ( ) , error:: Error > {
134156 let mut state = self . 0 . state . lock ( ) . expect ( "Cannot obtain read lock" ) ;
135157 match * state {
136- NotConnected => {
158+ ReconnectState :: NotConnected => {
137159 self . reconnect_spawn ( state) ;
138160 Err ( error:: Error :: Connection ( ConnectionReason :: NotConnected ) )
139161 }
140- Connected ( ref t) => {
162+ ReconnectState :: Connected ( ref t) => {
141163 let success = self . call_work ( t, a) ?;
142164 if !success {
143- * state = NotConnected ;
165+ * state = ReconnectState :: NotConnected ;
144166 self . reconnect_spawn ( state) ;
145167 }
146168 Ok ( ( ) )
147169 }
148- ConnectionFailed ( ref e) => {
170+ ReconnectState :: ConnectionFailed ( ref e) => {
149171 let mut lock = e. lock ( ) . expect ( "Poisioned lock" ) ;
150172 let e = match lock. take ( ) {
151173 Some ( e) => e,
152174 None => error:: Error :: Connection ( ConnectionReason :: NotConnected ) ,
153175 } ;
154176 mem:: drop ( lock) ;
155177
156- * state = NotConnected ;
178+ * state = ReconnectState :: NotConnected ;
157179 self . reconnect_spawn ( state) ;
158180 Err ( e)
159181 }
160- Connecting => Err ( error:: Error :: Connection ( ConnectionReason :: Connecting ) ) ,
182+ ReconnectState :: Connecting => {
183+ Err ( error:: Error :: Connection ( ConnectionReason :: Connecting ) )
184+ }
161185 }
162186 }
163187
@@ -170,17 +194,17 @@ where
170194 log:: info!( "Attempting to reconnect, current state: {:?}" , * state) ;
171195
172196 match * state {
173- Connected ( _) => {
197+ ReconnectState :: Connected ( _) => {
174198 return Either :: Right ( future:: err ( error:: Error :: Connection (
175199 ConnectionReason :: Connected ,
176200 ) ) ) ;
177201 }
178- Connecting => {
202+ ReconnectState :: Connecting => {
179203 return Either :: Right ( future:: err ( error:: Error :: Connection (
180204 ConnectionReason :: Connecting ,
181205 ) ) ) ;
182206 }
183- NotConnected | ConnectionFailed ( _) => ( ) ,
207+ ReconnectState :: NotConnected | ReconnectState :: ConnectionFailed ( _) => ( ) ,
184208 }
185209 * state = ReconnectState :: Connecting ;
186210
@@ -189,33 +213,54 @@ where
189213 let reconnect = self . clone ( ) ;
190214
191215 let connection_f = async move {
192- let connection = match timeout ( CONNECTION_TIMEOUT , ( reconnect. 0 . conn_fn ) ( ) ) . await {
193- Ok ( con_r) => con_r,
194- Err ( _) => Err ( error:: internal ( format ! (
195- "Connection timed-out after {} seconds" ,
196- CONNECTION_TIMEOUT_SECONDS
197- ) ) ) ,
198- } ;
216+ let mut connection_result = Err ( error:: internal ( "Initial connection failed" ) ) ;
217+ for i in 0 ..reconnect. 0 . reconnect_options . max_connection_attempts {
218+ log:: debug!(
219+ "Connection attempt {}/{}" ,
220+ i + 1 ,
221+ reconnect. 0 . reconnect_options. max_connection_attempts
222+ ) ;
223+ connection_result = match timeout (
224+ reconnect. 0 . reconnect_options . connection_timeout ,
225+ ( reconnect. 0 . conn_fn ) ( ) ,
226+ )
227+ . await
228+ {
229+ Ok ( con_r) => con_r,
230+ Err ( _) => Err ( error:: internal ( format ! (
231+ "Connection timed-out after {} seconds" ,
232+ reconnect. 0 . reconnect_options. connection_timeout. as_secs( )
233+ * reconnect. 0 . reconnect_options. max_connection_attempts
234+ ) ) ) ,
235+ } ;
236+ if connection_result. is_ok ( ) {
237+ break ;
238+ }
239+ }
199240
200241 let mut state = reconnect. 0 . state . lock ( ) . expect ( "Cannot obtain write lock" ) ;
201242
202243 match * state {
203- NotConnected | Connecting => match connection {
204- Ok ( t) => {
205- log:: info!( "Connection established" ) ;
206- * state = Connected ( t) ;
207- Ok ( ( ) )
208- }
209- Err ( e) => {
210- log:: error!( "Connection cannot be established: {}" , e) ;
211- * state = ConnectionFailed ( Mutex :: new ( Some ( e) ) ) ;
212- Err ( error:: Error :: Connection ( ConnectionReason :: ConnectionFailed ) )
244+ ReconnectState :: NotConnected | ReconnectState :: Connecting => {
245+ match connection_result {
246+ Ok ( t) => {
247+ log:: info!( "Connection established" ) ;
248+ * state = ReconnectState :: Connected ( t) ;
249+ Ok ( ( ) )
250+ }
251+ Err ( e) => {
252+ log:: error!( "Connection cannot be established: {}" , e) ;
253+ * state = ReconnectState :: ConnectionFailed ( Mutex :: new ( Some ( e) ) ) ;
254+ Err ( error:: Error :: Connection ( ConnectionReason :: ConnectionFailed ) )
255+ }
213256 }
214- } ,
215- ConnectionFailed ( _) => {
257+ }
258+ ReconnectState :: ConnectionFailed ( _) => {
216259 panic ! ( "The connection state wasn't reset before connecting" )
217260 }
218- Connected ( _) => panic ! ( "A connected state shouldn't be attempting to reconnect" ) ,
261+ ReconnectState :: Connected ( _) => {
262+ panic ! ( "A connected state shouldn't be attempting to reconnect" )
263+ }
219264 }
220265 } ;
221266
0 commit comments