@@ -30,35 +30,29 @@ pub(crate) struct MpcTlsClient {
3030}
3131
3232enum State {
33- ActiveIdle {
34- inner : InnerState ,
33+ Active {
34+ mpc : Pin < MpcFuture > ,
35+ inner : Box < InnerState > ,
3536 } ,
36- ActiveBusy {
37+ Busy {
38+ mpc : Pin < MpcFuture > ,
3739 fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
3840 } ,
39- SetDecryptIdle {
41+ SetDecrypt {
4042 enable : bool ,
41- inner : InnerState ,
43+ mpc : Pin < MpcFuture > ,
44+ inner : Box < InnerState > ,
4245 } ,
43- SetDecryptBusy {
44- fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
45- } ,
46- ClientCloseIdle {
47- inner : InnerState ,
48- } ,
49- ClientCloseBusy {
50- fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
51- } ,
52- ServerCloseIdle {
53- inner : InnerState ,
54- } ,
55- ServerCloseBusy {
56- fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
46+ ClientClose {
47+ mpc : Pin < MpcFuture > ,
48+ inner : Box < InnerState > ,
5749 } ,
58- ClosingIdle {
59- inner : InnerState ,
50+ ServerClose {
51+ mpc : Pin < MpcFuture > ,
52+ inner : Box < InnerState > ,
6053 } ,
61- ClosingBusy {
54+ Closing {
55+ mpc : Pin < MpcFuture > ,
6256 fut : Pin < Box < dyn Future < Output = Result < Box < InnerState > , ProverError > > > > ,
6357 } ,
6458 Finalizing {
@@ -85,31 +79,33 @@ impl MpcTlsClient {
8579 tls : ClientConnection ,
8680 decrypt : bool ,
8781 ) -> Self {
82+ let inner = InnerState {
83+ span,
84+ tls,
85+ vm,
86+ keys,
87+ mpc_ctrl,
88+ } ;
89+
8890 Self {
89- state : State :: ActiveIdle {
90- inner : InnerState {
91- span,
92- tls,
93- mpc : Box :: into_pin ( mpc) ,
94- vm,
95- keys,
96- mpc_ctrl,
97- decrypt,
98- } ,
91+ decrypt,
92+ state : State :: Active {
93+ mpc : Box :: into_pin ( mpc) ,
94+ inner : Box :: new ( inner) ,
9995 } ,
10096 }
10197 }
10298
10399 fn inner_client_mut ( & mut self ) -> Option < & mut ClientConnection > {
104- if let State :: ActiveIdle { inner } = & mut self . state {
100+ if let State :: Active { inner, .. } = & mut self . state {
105101 Some ( & mut inner. tls )
106102 } else {
107103 None
108104 }
109105 }
110106
111107 fn inner_client ( & self ) -> Option < & ClientConnection > {
112- if let State :: ActiveIdle { inner } = & self . state {
108+ if let State :: Active { inner, .. } = & self . state {
113109 Some ( & inner. tls )
114110 } else {
115111 None
@@ -118,17 +114,17 @@ impl MpcTlsClient {
118114
119115 fn set_status ( & mut self , status : Status ) -> bool {
120116 match std:: mem:: replace ( & mut self . state , State :: Error ) {
121- State :: ActiveIdle { inner } => match status {
117+ State :: Active { inner, mpc } => match status {
122118 Status :: SetDecrypt ( enable) => {
123- self . state = State :: SetDecryptIdle { enable, inner } ;
119+ self . state = State :: SetDecrypt { enable, mpc , inner } ;
124120 true
125121 }
126122 Status :: ClientClose => {
127- self . state = State :: ClientCloseIdle { inner } ;
123+ self . state = State :: ClientClose { mpc , inner } ;
128124 true
129125 }
130126 Status :: ServerClose => {
131- self . state = State :: ServerCloseIdle { inner } ;
127+ self . state = State :: ServerClose { mpc , inner } ;
132128 true
133129 }
134130 } ,
@@ -256,30 +252,74 @@ impl TlsClient for MpcTlsClient {
256252
257253 fn poll ( & mut self , cx : & mut std:: task:: Context ) -> Poll < Result < TlsOutput , Self :: Error > > {
258254 match std:: mem:: replace ( & mut self . state , State :: Error ) {
259- State :: ActiveIdle { inner } => {
260- trace ! ( "inner client is active idle" ) ;
261- let _ = inner. mpc . as_mut ( ) . poll ( cx) ?;
262- self . state = State :: ActiveBusy { fut : inner. run ( ) }
255+ State :: Active { mpc, inner } => {
256+ trace ! ( "inner client is active" ) ;
257+
258+ self . state = State :: Busy {
259+ mpc,
260+ fut : Box :: pin ( inner. run ( ) ) ,
261+ } ;
262+ self . poll ( cx)
263263 }
264- State :: ActiveBusy { fut } => match fut. as_mut ( ) . poll ( cx) {
265- Poll :: Ready ( res) => {
266- let inner = res?;
267- self . state . decrypt = * inner. decrypt ;
268- self . state = State :: ActiveIdle { inner } ;
264+ State :: Busy { mut mpc, mut fut } => {
265+ trace ! ( "inner client is busy" ) ;
266+ let _ = mpc. as_mut ( ) . poll ( cx) ?;
267+
268+ match fut. as_mut ( ) . poll ( cx) {
269+ Poll :: Ready ( res) => {
270+ let inner = res?;
271+ self . state = State :: Active { mpc, inner } ;
272+ }
273+ Poll :: Pending => self . state = State :: Busy { mpc, fut } ,
269274 }
270- Poll :: Pending => todo ! ( ) ,
271- } ,
272- State :: SetDecryptIdle {
273- enable,
274- inner : state,
275- } => todo ! ( ) ,
276- State :: SetDecryptBusy { fut } => todo ! ( ) ,
277- State :: ClientCloseIdle { inner : state } => todo ! ( ) ,
278- State :: ClientCloseBusy { fut } => todo ! ( ) ,
279- State :: ServerCloseIdle { inner : state } => todo ! ( ) ,
280- State :: ServerCloseBusy { fut } => todo ! ( ) ,
281- State :: ClosingIdle { inner : state } => todo ! ( ) ,
282- State :: ClosingBusy { fut } => todo ! ( ) ,
275+ Poll :: Pending
276+ }
277+ State :: SetDecrypt { enable, mpc, inner } => {
278+ self . state = State :: Busy {
279+ mpc,
280+ fut : Box :: pin ( inner. set_decrypt ( enable) ) ,
281+ } ;
282+ self . decrypt = enable;
283+
284+ debug ! ( "set decryption to {}" , enable) ;
285+ self . poll ( cx)
286+ }
287+ State :: ClientClose { mpc, inner } => {
288+ debug ! ( "attempting to close connection clientside" ) ;
289+ self . state = State :: Closing {
290+ mpc,
291+ fut : Box :: pin ( inner. client_close ( ) ) ,
292+ } ;
293+ self . poll ( cx)
294+ }
295+ State :: ServerClose { mpc, inner } => {
296+ debug ! ( "attempting to close connection serverside" ) ;
297+ self . state = State :: Closing {
298+ mpc,
299+ fut : Box :: pin ( inner. server_close ( ) ) ,
300+ } ;
301+ self . poll ( cx)
302+ }
303+ State :: Closing { mut mpc, mut fut } => {
304+ let Poll :: Ready ( res) = fut. as_mut ( ) . poll ( cx) else {
305+ self . state = State :: Closing { mpc, fut } ;
306+ return Poll :: Pending ;
307+ } ;
308+ let inner = res?;
309+ let Poll :: Ready ( res) = mpc. as_mut ( ) . poll ( cx) else {
310+ self . state = State :: Closing {
311+ mpc,
312+ fut : Box :: pin ( inner. client_close ( ) ) ,
313+ } ;
314+ return Poll :: Pending ;
315+ } ;
316+
317+ let ( ctx, transcript) = res?;
318+ self . state = State :: Finalizing {
319+ fut : Box :: pin ( inner. finalize ( ctx, transcript) ) ,
320+ } ;
321+ self . poll ( cx)
322+ }
283323 State :: Finalizing { mut fut } => match fut. as_mut ( ) . poll ( cx) {
284324 Poll :: Ready ( output) => {
285325 let ( state, ctx, tls_transcript) = output?;
@@ -323,18 +363,45 @@ impl TlsClient for MpcTlsClient {
323363struct InnerState {
324364 span : Span ,
325365 tls : ClientConnection ,
326- mpc : Pin < MpcFuture > ,
327366 vm : Arc < Mutex < Deap < ProverMpc , ProverZk > > > ,
328367 keys : SessionKeys ,
329368 mpc_ctrl : LeaderCtrl ,
330369}
331370
332371impl InnerState {
333372 #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
373+ async fn set_decrypt ( self : Box < Self > , enable : bool ) -> Result < Box < Self > , ProverError > {
374+ self . mpc_ctrl . enable_decryption ( enable) . await ?;
375+ self . run ( ) . await
376+ }
377+
378+ #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
379+ async fn client_close ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
380+ if self . tls . plaintext_is_empty ( ) && self . tls . is_empty ( ) . await ? {
381+ if let Err ( e) = self . tls . send_close_notify ( ) . await {
382+ warn ! ( "failed to send close_notify to server: {}" , e) ;
383+ } ;
384+
385+ self . mpc_ctrl . stop ( ) . await ?;
386+ debug ! ( "closed connection" ) ;
387+ }
388+ self . run ( ) . await
389+ }
390+
391+ #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
392+ async fn server_close ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
393+ if self . tls . plaintext_is_empty ( ) && self . tls . is_empty ( ) . await ? {
394+ self . tls . server_closed ( ) . await ?;
395+
396+ self . mpc_ctrl . stop ( ) . await ?;
397+ debug ! ( "closed connection" ) ;
398+ }
399+ self . run ( ) . await
400+ }
401+
402+ #[ instrument( parent = & self . span, level = "trace" , skip_all, err) ]
334403 async fn run ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
335- trace ! ( "processing new packets" ) ;
336404 self . tls . process_new_packets ( ) . await ?;
337-
338405 Ok ( self )
339406 }
340407
0 commit comments