Skip to content

Commit 2ddc731

Browse files
committed
WIP: tidying up state room...
1 parent bf52da0 commit 2ddc731

File tree

1 file changed

+130
-63
lines changed
  • crates/tlsn/src/prover/client

1 file changed

+130
-63
lines changed

crates/tlsn/src/prover/client/mpc.rs

Lines changed: 130 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -30,35 +30,29 @@ pub(crate) struct MpcTlsClient {
3030
}
3131

3232
enum State {
33-
ActiveIdle {
34-
inner: InnerState,
33+
Active {
34+
mpc: Pin<MpcFuture>,
35+
inner: Box<InnerState>,
3536
},
36-
ActiveBusy {
37+
Busy {
38+
mpc: Pin<MpcFuture>,
3739
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
3840
},
39-
SetDecryptIdle {
41+
SetDecrypt {
4042
enable: bool,
41-
inner: InnerState,
43+
mpc: Pin<MpcFuture>,
44+
inner: Box<InnerState>,
4245
},
43-
SetDecryptBusy {
44-
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
45-
},
46-
ClientCloseIdle {
47-
inner: InnerState,
48-
},
49-
ClientCloseBusy {
50-
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
51-
},
52-
ServerCloseIdle {
53-
inner: InnerState,
54-
},
55-
ServerCloseBusy {
56-
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
46+
ClientClose {
47+
mpc: Pin<MpcFuture>,
48+
inner: Box<InnerState>,
5749
},
58-
ClosingIdle {
59-
inner: InnerState,
50+
ServerClose {
51+
mpc: Pin<MpcFuture>,
52+
inner: Box<InnerState>,
6053
},
61-
ClosingBusy {
54+
Closing {
55+
mpc: Pin<MpcFuture>,
6256
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
6357
},
6458
Finalizing {
@@ -85,31 +79,33 @@ impl MpcTlsClient {
8579
tls: ClientConnection,
8680
decrypt: bool,
8781
) -> Self {
82+
let inner = InnerState {
83+
span,
84+
tls,
85+
vm,
86+
keys,
87+
mpc_ctrl,
88+
};
89+
8890
Self {
89-
state: State::ActiveIdle {
90-
inner: InnerState {
91-
span,
92-
tls,
93-
mpc: Box::into_pin(mpc),
94-
vm,
95-
keys,
96-
mpc_ctrl,
97-
decrypt,
98-
},
91+
decrypt,
92+
state: State::Active {
93+
mpc: Box::into_pin(mpc),
94+
inner: Box::new(inner),
9995
},
10096
}
10197
}
10298

10399
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
104-
if let State::ActiveIdle { inner } = &mut self.state {
100+
if let State::Active { inner, .. } = &mut self.state {
105101
Some(&mut inner.tls)
106102
} else {
107103
None
108104
}
109105
}
110106

111107
fn inner_client(&self) -> Option<&ClientConnection> {
112-
if let State::ActiveIdle { inner } = &self.state {
108+
if let State::Active { inner, .. } = &self.state {
113109
Some(&inner.tls)
114110
} else {
115111
None
@@ -118,17 +114,17 @@ impl MpcTlsClient {
118114

119115
fn set_status(&mut self, status: Status) -> bool {
120116
match std::mem::replace(&mut self.state, State::Error) {
121-
State::ActiveIdle { inner } => match status {
117+
State::Active { inner, mpc } => match status {
122118
Status::SetDecrypt(enable) => {
123-
self.state = State::SetDecryptIdle { enable, inner };
119+
self.state = State::SetDecrypt { enable, mpc, inner };
124120
true
125121
}
126122
Status::ClientClose => {
127-
self.state = State::ClientCloseIdle { inner };
123+
self.state = State::ClientClose { mpc, inner };
128124
true
129125
}
130126
Status::ServerClose => {
131-
self.state = State::ServerCloseIdle { inner };
127+
self.state = State::ServerClose { mpc, inner };
132128
true
133129
}
134130
},
@@ -256,30 +252,74 @@ impl TlsClient for MpcTlsClient {
256252

257253
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
258254
match std::mem::replace(&mut self.state, State::Error) {
259-
State::ActiveIdle { inner } => {
260-
trace!("inner client is active idle");
261-
let _ = inner.mpc.as_mut().poll(cx)?;
262-
self.state = State::ActiveBusy { fut: inner.run() }
255+
State::Active { mpc, inner } => {
256+
trace!("inner client is active");
257+
258+
self.state = State::Busy {
259+
mpc,
260+
fut: Box::pin(inner.run()),
261+
};
262+
self.poll(cx)
263263
}
264-
State::ActiveBusy { fut } => match fut.as_mut().poll(cx) {
265-
Poll::Ready(res) => {
266-
let inner = res?;
267-
self.state.decrypt = *inner.decrypt;
268-
self.state = State::ActiveIdle { inner };
264+
State::Busy { mut mpc, mut fut } => {
265+
trace!("inner client is busy");
266+
let _ = mpc.as_mut().poll(cx)?;
267+
268+
match fut.as_mut().poll(cx) {
269+
Poll::Ready(res) => {
270+
let inner = res?;
271+
self.state = State::Active { mpc, inner };
272+
}
273+
Poll::Pending => self.state = State::Busy { mpc, fut },
269274
}
270-
Poll::Pending => todo!(),
271-
},
272-
State::SetDecryptIdle {
273-
enable,
274-
inner: state,
275-
} => todo!(),
276-
State::SetDecryptBusy { fut } => todo!(),
277-
State::ClientCloseIdle { inner: state } => todo!(),
278-
State::ClientCloseBusy { fut } => todo!(),
279-
State::ServerCloseIdle { inner: state } => todo!(),
280-
State::ServerCloseBusy { fut } => todo!(),
281-
State::ClosingIdle { inner: state } => todo!(),
282-
State::ClosingBusy { fut } => todo!(),
275+
Poll::Pending
276+
}
277+
State::SetDecrypt { enable, mpc, inner } => {
278+
self.state = State::Busy {
279+
mpc,
280+
fut: Box::pin(inner.set_decrypt(enable)),
281+
};
282+
self.decrypt = enable;
283+
284+
debug!("set decryption to {}", enable);
285+
self.poll(cx)
286+
}
287+
State::ClientClose { mpc, inner } => {
288+
debug!("attempting to close connection clientside");
289+
self.state = State::Closing {
290+
mpc,
291+
fut: Box::pin(inner.client_close()),
292+
};
293+
self.poll(cx)
294+
}
295+
State::ServerClose { mpc, inner } => {
296+
debug!("attempting to close connection serverside");
297+
self.state = State::Closing {
298+
mpc,
299+
fut: Box::pin(inner.server_close()),
300+
};
301+
self.poll(cx)
302+
}
303+
State::Closing { mut mpc, mut fut } => {
304+
let Poll::Ready(res) = fut.as_mut().poll(cx) else {
305+
self.state = State::Closing { mpc, fut };
306+
return Poll::Pending;
307+
};
308+
let inner = res?;
309+
let Poll::Ready(res) = mpc.as_mut().poll(cx) else {
310+
self.state = State::Closing {
311+
mpc,
312+
fut: Box::pin(inner.client_close()),
313+
};
314+
return Poll::Pending;
315+
};
316+
317+
let (ctx, transcript) = res?;
318+
self.state = State::Finalizing {
319+
fut: Box::pin(inner.finalize(ctx, transcript)),
320+
};
321+
self.poll(cx)
322+
}
283323
State::Finalizing { mut fut } => match fut.as_mut().poll(cx) {
284324
Poll::Ready(output) => {
285325
let (state, ctx, tls_transcript) = output?;
@@ -323,18 +363,45 @@ impl TlsClient for MpcTlsClient {
323363
struct InnerState {
324364
span: Span,
325365
tls: ClientConnection,
326-
mpc: Pin<MpcFuture>,
327366
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
328367
keys: SessionKeys,
329368
mpc_ctrl: LeaderCtrl,
330369
}
331370

332371
impl InnerState {
333372
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
373+
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
374+
self.mpc_ctrl.enable_decryption(enable).await?;
375+
self.run().await
376+
}
377+
378+
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
379+
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
380+
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
381+
if let Err(e) = self.tls.send_close_notify().await {
382+
warn!("failed to send close_notify to server: {}", e);
383+
};
384+
385+
self.mpc_ctrl.stop().await?;
386+
debug!("closed connection");
387+
}
388+
self.run().await
389+
}
390+
391+
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
392+
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
393+
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
394+
self.tls.server_closed().await?;
395+
396+
self.mpc_ctrl.stop().await?;
397+
debug!("closed connection");
398+
}
399+
self.run().await
400+
}
401+
402+
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
334403
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
335-
trace!("processing new packets");
336404
self.tls.process_new_packets().await?;
337-
338405
Ok(self)
339406
}
340407

0 commit comments

Comments
 (0)