@@ -8,7 +8,7 @@ use crate::{
88 } ,
99 tag:: verify_tags,
1010} ;
11- use futures:: Future ;
11+ use futures:: { Future , FutureExt } ;
1212use mpc_tls:: { LeaderCtrl , SessionKeys } ;
1313use mpz_common:: Context ;
1414use mpz_vm_core:: Execute ;
@@ -38,21 +38,17 @@ enum State {
3838 mpc : Pin < MpcFuture > ,
3939 fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
4040 } ,
41- SetDecrypt {
42- enable : bool ,
43- mpc : Pin < MpcFuture > ,
44- inner : Box < InnerState > ,
45- } ,
4641 ClientClose {
4742 mpc : Pin < MpcFuture > ,
48- inner : Box < InnerState > ,
43+ fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
4944 } ,
5045 ServerClose {
5146 mpc : Pin < MpcFuture > ,
52- inner : Box < InnerState > ,
47+ fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
5348 } ,
5449 Closing {
55- mpc : Pin < MpcFuture > ,
50+ ctx : Context ,
51+ transcript : TlsTranscript ,
5652 fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
5753 } ,
5854 Finalizing {
@@ -62,13 +58,6 @@ enum State {
6258 Error ,
6359}
6460
65- #[ derive( Debug , Copy , Clone ) ]
66- enum Status {
67- SetDecrypt ( bool ) ,
68- ClientClose ,
69- ServerClose ,
70- }
71-
7261impl MpcTlsClient {
7362 pub ( crate ) fn new (
7463 mpc : MpcFuture ,
@@ -85,6 +74,7 @@ impl MpcTlsClient {
8574 vm,
8675 keys,
8776 mpc_ctrl,
77+ closed : false ,
8878 } ;
8979
9080 Self {
@@ -111,29 +101,6 @@ impl MpcTlsClient {
111101 None
112102 }
113103 }
114-
115- fn set_status ( & mut self , status : Status ) -> bool {
116- match std:: mem:: replace ( & mut self . state , State :: Error ) {
117- State :: Active { inner, mpc } => match status {
118- Status :: SetDecrypt ( enable) => {
119- self . state = State :: SetDecrypt { enable, mpc, inner } ;
120- true
121- }
122- Status :: ClientClose => {
123- self . state = State :: ClientClose { mpc, inner } ;
124- true
125- }
126- Status :: ServerClose => {
127- self . state = State :: ServerClose { mpc, inner } ;
128- true
129- }
130- } ,
131- other => {
132- self . state = other;
133- false
134- }
135- }
136- }
137104}
138105
139106impl TlsClient for MpcTlsClient {
@@ -220,30 +187,57 @@ impl TlsClient for MpcTlsClient {
220187 }
221188
222189 fn client_close ( & mut self ) -> Result < ( ) , Self :: Error > {
223- if self . set_status ( Status :: ClientClose ) {
224- return Ok ( ( ) ) ;
190+ match std:: mem:: replace ( & mut self . state , State :: Error ) {
191+ State :: Active { inner, mpc } => {
192+ self . state = State :: ClientClose {
193+ mpc,
194+ fut : Box :: pin ( inner. client_close ( ) ) ,
195+ } ;
196+ Ok ( ( ) )
197+ }
198+ other => {
199+ self . state = other;
200+ Err ( ProverError :: state (
201+ "unable to close connection, client is not in active state" ,
202+ ) )
203+ }
225204 }
226- Err ( ProverError :: state (
227- "unable to close tls connection, client is not in active state" ,
228- ) )
229205 }
230206
231207 fn server_close ( & mut self ) -> Result < ( ) , Self :: Error > {
232- if self . set_status ( Status :: ServerClose ) {
233- return Ok ( ( ) ) ;
208+ match std:: mem:: replace ( & mut self . state , State :: Error ) {
209+ State :: Active { inner, mpc } => {
210+ self . state = State :: ServerClose {
211+ mpc,
212+ fut : Box :: pin ( inner. server_close ( ) ) ,
213+ } ;
214+ Ok ( ( ) )
215+ }
216+ other => {
217+ self . state = other;
218+ Err ( ProverError :: state (
219+ "unable to close connection, client is not in active state" ,
220+ ) )
221+ }
234222 }
235- Err ( ProverError :: state (
236- "unable to close tls connection, client is not in active state" ,
237- ) )
238223 }
239224
240225 fn enable_decryption ( & mut self , enable : bool ) -> Result < ( ) , Self :: Error > {
241- if self . set_status ( Status :: SetDecrypt ( enable) ) {
242- return Ok ( ( ) ) ;
226+ match std:: mem:: replace ( & mut self . state , State :: Error ) {
227+ State :: Active { inner, mpc } => {
228+ self . state = State :: Busy {
229+ mpc,
230+ fut : Box :: pin ( inner. set_decrypt ( enable) ) ,
231+ } ;
232+ Ok ( ( ) )
233+ }
234+ other => {
235+ self . state = other;
236+ Err ( ProverError :: state (
237+ "unable to enable decryption, client is not in active state" ,
238+ ) )
239+ }
243240 }
244- Err ( ProverError :: state (
245- "unable to set decryption, client is not in active state" ,
246- ) )
247241 }
248242
249243 fn is_decrypting ( & self ) -> bool {
@@ -263,6 +257,9 @@ impl TlsClient for MpcTlsClient {
263257 }
264258 State :: Busy { mut mpc, mut fut } => {
265259 trace ! ( "inner client is busy" ) ;
260+
261+ // mpc future cannot be ready at this point becaus we have not called `stop`
262+ // yet.
266263 let _ = mpc. as_mut ( ) . poll ( cx) ?;
267264
268265 match fut. as_mut ( ) . poll ( cx) {
@@ -274,56 +271,78 @@ impl TlsClient for MpcTlsClient {
274271 }
275272 Poll :: Pending
276273 }
277- State :: SetDecrypt { enable, mpc, inner } => {
278- self . state = State :: Busy {
279- mpc,
280- fut : Box :: pin ( inner. set_decrypt ( enable) ) ,
281- } ;
282- self . decrypt = enable;
283-
284- debug ! ( "set decryption to {}" , enable) ;
285- self . poll ( cx)
286- }
287- State :: ClientClose { mpc, inner } => {
274+ State :: ClientClose { mut mpc, mut fut } => {
288275 debug ! ( "attempting to close connection clientside" ) ;
289- self . state = State :: Closing {
290- mpc,
291- fut : Box :: pin ( inner. client_close ( ) ) ,
292- } ;
276+ match ( fut. poll_unpin ( cx) ?, mpc. poll_unpin ( cx) ?) {
277+ ( Poll :: Ready ( inner) , Poll :: Ready ( ( ctx, transcript) ) ) => {
278+ self . state = State :: Finalizing {
279+ fut : Box :: pin ( inner. finalize ( ctx, transcript) ) ,
280+ } ;
281+ }
282+ ( Poll :: Ready ( inner) , Poll :: Pending ) => {
283+ self . state = State :: ClientClose {
284+ mpc,
285+ fut : Box :: pin ( inner. client_close ( ) ) ,
286+ } ;
287+ }
288+ ( Poll :: Pending , Poll :: Ready ( ( ctx, transcript) ) ) => {
289+ self . state = State :: Closing {
290+ ctx,
291+ transcript,
292+ fut,
293+ } ;
294+ }
295+ ( Poll :: Pending , Poll :: Pending ) => self . state = State :: ClientClose { mpc, fut } ,
296+ }
293297 self . poll ( cx)
294298 }
295- State :: ServerClose { mpc, inner } => {
299+ State :: ServerClose { mut mpc, mut fut } => {
296300 debug ! ( "attempting to close connection serverside" ) ;
297- self . state = State :: Closing {
298- mpc,
299- fut : Box :: pin ( inner. server_close ( ) ) ,
300- } ;
301+ match ( fut. poll_unpin ( cx) ?, mpc. poll_unpin ( cx) ?) {
302+ ( Poll :: Ready ( inner) , Poll :: Ready ( ( ctx, transcript) ) ) => {
303+ self . state = State :: Finalizing {
304+ fut : Box :: pin ( inner. finalize ( ctx, transcript) ) ,
305+ } ;
306+ }
307+ ( Poll :: Ready ( inner) , Poll :: Pending ) => {
308+ self . state = State :: ServerClose {
309+ mpc,
310+ fut : Box :: pin ( inner. server_close ( ) ) ,
311+ } ;
312+ }
313+ ( Poll :: Pending , Poll :: Ready ( ( ctx, transcript) ) ) => {
314+ self . state = State :: Closing {
315+ ctx,
316+ transcript,
317+ fut,
318+ } ;
319+ }
320+ ( Poll :: Pending , Poll :: Pending ) => self . state = State :: ServerClose { mpc, fut } ,
321+ }
301322 self . poll ( cx)
302323 }
303- State :: Closing { mut mpc, mut fut } => {
304- let Poll :: Ready ( res) = fut. as_mut ( ) . poll ( cx) else {
305- self . state = State :: Closing { mpc, fut } ;
306- return Poll :: Pending ;
307- } ;
308- let inner = res?;
309- let Poll :: Ready ( res) = mpc. as_mut ( ) . poll ( cx) else {
324+ State :: Closing {
325+ ctx,
326+ transcript,
327+ mut fut,
328+ } => {
329+ if let Poll :: Ready ( inner) = fut. poll_unpin ( cx) ? {
330+ self . state = State :: Finalizing {
331+ fut : Box :: pin ( inner. finalize ( ctx, transcript) ) ,
332+ } ;
333+ } else {
310334 self . state = State :: Closing {
311- mpc,
312- fut : Box :: pin ( inner. client_close ( ) ) ,
335+ ctx,
336+ transcript,
337+ fut,
313338 } ;
314- return Poll :: Pending ;
315- } ;
316-
317- let ( ctx, transcript) = res?;
318- self . state = State :: Finalizing {
319- fut : Box :: pin ( inner. finalize ( ctx, transcript) ) ,
320- } ;
339+ }
321340 self . poll ( cx)
322341 }
323342 State :: Finalizing { mut fut } => match fut. as_mut ( ) . poll ( cx) {
324343 Poll :: Ready ( output) => {
325- let ( state , ctx, tls_transcript) = output?;
326- let InnerState { vm, keys, .. } = state ;
344+ let ( inner , ctx, tls_transcript) = output?;
345+ let InnerState { vm, keys, .. } = inner ;
327346
328347 let transcript = tls_transcript
329348 . to_transcript ( )
@@ -366,6 +385,7 @@ struct InnerState {
366385 vm : Arc < Mutex < Deap < ProverMpc , ProverZk > > > ,
367386 keys : SessionKeys ,
368387 mpc_ctrl : LeaderCtrl ,
388+ closed : bool ,
369389}
370390
371391impl InnerState {
@@ -377,23 +397,25 @@ impl InnerState {
377397
378398 #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
379399 async fn client_close ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
380- if self . tls . plaintext_is_empty ( ) && self . tls . is_empty ( ) . await ? {
400+ if self . tls . plaintext_is_empty ( ) && self . tls . is_empty ( ) . await ? && ! self . closed {
381401 if let Err ( e) = self . tls . send_close_notify ( ) . await {
382402 warn ! ( "failed to send close_notify to server: {}" , e) ;
383403 } ;
384404
385405 self . mpc_ctrl . stop ( ) . await ?;
406+ self . closed = true ;
386407 debug ! ( "closed connection" ) ;
387408 }
388409 self . run ( ) . await
389410 }
390411
391412 #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
392413 async fn server_close ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
393- if self . tls . plaintext_is_empty ( ) && self . tls . is_empty ( ) . await ? {
414+ if self . tls . plaintext_is_empty ( ) && self . tls . is_empty ( ) . await ? && ! self . closed {
394415 self . tls . server_closed ( ) . await ?;
395416
396417 self . mpc_ctrl . stop ( ) . await ?;
418+ self . closed = true ;
397419 debug ! ( "closed connection" ) ;
398420 }
399421 self . run ( ) . await
0 commit comments