Skip to content

Commit 5935604

Browse files
committed
add correct state transitions
1 parent 2ddc731 commit 5935604

File tree

1 file changed

+116
-94
lines changed
  • crates/tlsn/src/prover/client

1 file changed

+116
-94
lines changed

crates/tlsn/src/prover/client/mpc.rs

Lines changed: 116 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
},
99
tag::verify_tags,
1010
};
11-
use futures::Future;
11+
use futures::{Future, FutureExt};
1212
use mpc_tls::{LeaderCtrl, SessionKeys};
1313
use mpz_common::Context;
1414
use mpz_vm_core::Execute;
@@ -38,21 +38,17 @@ enum State {
3838
mpc: Pin<MpcFuture>,
3939
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
4040
},
41-
SetDecrypt {
42-
enable: bool,
43-
mpc: Pin<MpcFuture>,
44-
inner: Box<InnerState>,
45-
},
4641
ClientClose {
4742
mpc: Pin<MpcFuture>,
48-
inner: Box<InnerState>,
43+
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
4944
},
5045
ServerClose {
5146
mpc: Pin<MpcFuture>,
52-
inner: Box<InnerState>,
47+
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
5348
},
5449
Closing {
55-
mpc: Pin<MpcFuture>,
50+
ctx: Context,
51+
transcript: TlsTranscript,
5652
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
5753
},
5854
Finalizing {
@@ -62,13 +58,6 @@ enum State {
6258
Error,
6359
}
6460

65-
#[derive(Debug, Copy, Clone)]
66-
enum Status {
67-
SetDecrypt(bool),
68-
ClientClose,
69-
ServerClose,
70-
}
71-
7261
impl MpcTlsClient {
7362
pub(crate) fn new(
7463
mpc: MpcFuture,
@@ -85,6 +74,7 @@ impl MpcTlsClient {
8574
vm,
8675
keys,
8776
mpc_ctrl,
77+
closed: false,
8878
};
8979

9080
Self {
@@ -111,29 +101,6 @@ impl MpcTlsClient {
111101
None
112102
}
113103
}
114-
115-
fn set_status(&mut self, status: Status) -> bool {
116-
match std::mem::replace(&mut self.state, State::Error) {
117-
State::Active { inner, mpc } => match status {
118-
Status::SetDecrypt(enable) => {
119-
self.state = State::SetDecrypt { enable, mpc, inner };
120-
true
121-
}
122-
Status::ClientClose => {
123-
self.state = State::ClientClose { mpc, inner };
124-
true
125-
}
126-
Status::ServerClose => {
127-
self.state = State::ServerClose { mpc, inner };
128-
true
129-
}
130-
},
131-
other => {
132-
self.state = other;
133-
false
134-
}
135-
}
136-
}
137104
}
138105

139106
impl TlsClient for MpcTlsClient {
@@ -220,30 +187,57 @@ impl TlsClient for MpcTlsClient {
220187
}
221188

222189
fn client_close(&mut self) -> Result<(), Self::Error> {
223-
if self.set_status(Status::ClientClose) {
224-
return Ok(());
190+
match std::mem::replace(&mut self.state, State::Error) {
191+
State::Active { inner, mpc } => {
192+
self.state = State::ClientClose {
193+
mpc,
194+
fut: Box::pin(inner.client_close()),
195+
};
196+
Ok(())
197+
}
198+
other => {
199+
self.state = other;
200+
Err(ProverError::state(
201+
"unable to close connection, client is not in active state",
202+
))
203+
}
225204
}
226-
Err(ProverError::state(
227-
"unable to close tls connection, client is not in active state",
228-
))
229205
}
230206

231207
fn server_close(&mut self) -> Result<(), Self::Error> {
232-
if self.set_status(Status::ServerClose) {
233-
return Ok(());
208+
match std::mem::replace(&mut self.state, State::Error) {
209+
State::Active { inner, mpc } => {
210+
self.state = State::ServerClose {
211+
mpc,
212+
fut: Box::pin(inner.server_close()),
213+
};
214+
Ok(())
215+
}
216+
other => {
217+
self.state = other;
218+
Err(ProverError::state(
219+
"unable to close connection, client is not in active state",
220+
))
221+
}
234222
}
235-
Err(ProverError::state(
236-
"unable to close tls connection, client is not in active state",
237-
))
238223
}
239224

240225
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> {
241-
if self.set_status(Status::SetDecrypt(enable)) {
242-
return Ok(());
226+
match std::mem::replace(&mut self.state, State::Error) {
227+
State::Active { inner, mpc } => {
228+
self.state = State::Busy {
229+
mpc,
230+
fut: Box::pin(inner.set_decrypt(enable)),
231+
};
232+
Ok(())
233+
}
234+
other => {
235+
self.state = other;
236+
Err(ProverError::state(
237+
"unable to enable decryption, client is not in active state",
238+
))
239+
}
243240
}
244-
Err(ProverError::state(
245-
"unable to set decryption, client is not in active state",
246-
))
247241
}
248242

249243
fn is_decrypting(&self) -> bool {
@@ -263,6 +257,9 @@ impl TlsClient for MpcTlsClient {
263257
}
264258
State::Busy { mut mpc, mut fut } => {
265259
trace!("inner client is busy");
260+
261+
// mpc future cannot be ready at this point becaus we have not called `stop`
262+
// yet.
266263
let _ = mpc.as_mut().poll(cx)?;
267264

268265
match fut.as_mut().poll(cx) {
@@ -274,56 +271,78 @@ impl TlsClient for MpcTlsClient {
274271
}
275272
Poll::Pending
276273
}
277-
State::SetDecrypt { enable, mpc, inner } => {
278-
self.state = State::Busy {
279-
mpc,
280-
fut: Box::pin(inner.set_decrypt(enable)),
281-
};
282-
self.decrypt = enable;
283-
284-
debug!("set decryption to {}", enable);
285-
self.poll(cx)
286-
}
287-
State::ClientClose { mpc, inner } => {
274+
State::ClientClose { mut mpc, mut fut } => {
288275
debug!("attempting to close connection clientside");
289-
self.state = State::Closing {
290-
mpc,
291-
fut: Box::pin(inner.client_close()),
292-
};
276+
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
277+
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
278+
self.state = State::Finalizing {
279+
fut: Box::pin(inner.finalize(ctx, transcript)),
280+
};
281+
}
282+
(Poll::Ready(inner), Poll::Pending) => {
283+
self.state = State::ClientClose {
284+
mpc,
285+
fut: Box::pin(inner.client_close()),
286+
};
287+
}
288+
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
289+
self.state = State::Closing {
290+
ctx,
291+
transcript,
292+
fut,
293+
};
294+
}
295+
(Poll::Pending, Poll::Pending) => self.state = State::ClientClose { mpc, fut },
296+
}
293297
self.poll(cx)
294298
}
295-
State::ServerClose { mpc, inner } => {
299+
State::ServerClose { mut mpc, mut fut } => {
296300
debug!("attempting to close connection serverside");
297-
self.state = State::Closing {
298-
mpc,
299-
fut: Box::pin(inner.server_close()),
300-
};
301+
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
302+
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
303+
self.state = State::Finalizing {
304+
fut: Box::pin(inner.finalize(ctx, transcript)),
305+
};
306+
}
307+
(Poll::Ready(inner), Poll::Pending) => {
308+
self.state = State::ServerClose {
309+
mpc,
310+
fut: Box::pin(inner.server_close()),
311+
};
312+
}
313+
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
314+
self.state = State::Closing {
315+
ctx,
316+
transcript,
317+
fut,
318+
};
319+
}
320+
(Poll::Pending, Poll::Pending) => self.state = State::ServerClose { mpc, fut },
321+
}
301322
self.poll(cx)
302323
}
303-
State::Closing { mut mpc, mut fut } => {
304-
let Poll::Ready(res) = fut.as_mut().poll(cx) else {
305-
self.state = State::Closing { mpc, fut };
306-
return Poll::Pending;
307-
};
308-
let inner = res?;
309-
let Poll::Ready(res) = mpc.as_mut().poll(cx) else {
324+
State::Closing {
325+
ctx,
326+
transcript,
327+
mut fut,
328+
} => {
329+
if let Poll::Ready(inner) = fut.poll_unpin(cx)? {
330+
self.state = State::Finalizing {
331+
fut: Box::pin(inner.finalize(ctx, transcript)),
332+
};
333+
} else {
310334
self.state = State::Closing {
311-
mpc,
312-
fut: Box::pin(inner.client_close()),
335+
ctx,
336+
transcript,
337+
fut,
313338
};
314-
return Poll::Pending;
315-
};
316-
317-
let (ctx, transcript) = res?;
318-
self.state = State::Finalizing {
319-
fut: Box::pin(inner.finalize(ctx, transcript)),
320-
};
339+
}
321340
self.poll(cx)
322341
}
323342
State::Finalizing { mut fut } => match fut.as_mut().poll(cx) {
324343
Poll::Ready(output) => {
325-
let (state, ctx, tls_transcript) = output?;
326-
let InnerState { vm, keys, .. } = state;
344+
let (inner, ctx, tls_transcript) = output?;
345+
let InnerState { vm, keys, .. } = inner;
327346

328347
let transcript = tls_transcript
329348
.to_transcript()
@@ -366,6 +385,7 @@ struct InnerState {
366385
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
367386
keys: SessionKeys,
368387
mpc_ctrl: LeaderCtrl,
388+
closed: bool,
369389
}
370390

371391
impl InnerState {
@@ -377,23 +397,25 @@ impl InnerState {
377397

378398
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
379399
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
380-
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
400+
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed {
381401
if let Err(e) = self.tls.send_close_notify().await {
382402
warn!("failed to send close_notify to server: {}", e);
383403
};
384404

385405
self.mpc_ctrl.stop().await?;
406+
self.closed = true;
386407
debug!("closed connection");
387408
}
388409
self.run().await
389410
}
390411

391412
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
392413
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
393-
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
414+
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed {
394415
self.tls.server_closed().await?;
395416

396417
self.mpc_ctrl.stop().await?;
418+
self.closed = true;
397419
debug!("closed connection");
398420
}
399421
self.run().await

0 commit comments

Comments
 (0)