Skip to content

Commit 072a454

Browse files
committed
refactor(tls-client): make tls_client compatible with a synchronous API (#1027)
* refactor(tls-client): make `write_plaintext` sync and remove async api * restore `complete_io` * do not potentially block in `write_all_plaintext`
1 parent c51331d commit 072a454

File tree

2 files changed

+52
-114
lines changed

2 files changed

+52
-114
lines changed

crates/tls/client/src/conn.rs

Lines changed: 19 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ impl ConnectionCommon {
457457
return Err(Error::CorruptMessage);
458458
}
459459

460+
// Process outgoing plaintext buffer and encrypt messages.
461+
self.flush_plaintext().await?;
462+
460463
// Process new messages.
461464
while let Some(msg) = self.message_deframer.frames.pop_front() {
462465
// If we're not decrypting yet, we process it immediately. Otherwise it will be
@@ -508,25 +511,22 @@ impl ConnectionCommon {
508511
Ok(state)
509512
}
510513

511-
/// Write buffer into connection.
512-
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
513-
if let Ok(st) = &mut self.state {
514-
st.perhaps_write_key_update(&mut self.common_state).await;
514+
/// Writes plaintext `buf` into an internal buffer. May not fully process the
515+
/// whole buffer and returns the processed length.
516+
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
517+
if buf.is_empty() {
518+
// Don't send empty fragments.
519+
return Ok(0);
515520
}
516-
self.common_state.send_some_plaintext(buf).await
521+
522+
let len = self.sendable_plaintext.append_limited_copy(buf);
523+
Ok(len)
517524
}
518525

519-
/// Write entire buffer into connection.
520-
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
521-
let mut pos = 0;
522-
while pos < buf.len() {
523-
pos += self.write_plaintext(&buf[pos..]).await?;
524-
}
525-
self.backend.flush().await?;
526-
while let Some(msg) = self.backend.next_outgoing().await? {
527-
self.queue_tls_message(msg);
528-
}
529-
Ok(pos)
526+
/// Writes the entire plaintext `buf` into an internal buffer.
527+
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
528+
self.sendable_plaintext.append(buf.to_vec());
529+
Ok(())
530530
}
531531

532532
/// Read TLS content from `rd`. This method does internal
@@ -782,15 +782,6 @@ impl CommonState {
782782
}
783783
}
784784

785-
/// Send plaintext application data, fragmenting and
786-
/// encrypting it as it goes out.
787-
///
788-
/// If internal buffers are too small, this function will not accept
789-
/// all the data.
790-
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
791-
self.send_plain(data, Limit::Yes).await
792-
}
793-
794785
// Changing the keys must not span any fragmented handshake
795786
// messages. Otherwise the defragmented messages will have
796787
// been protected with two different record layer protections,
@@ -931,32 +922,6 @@ impl CommonState {
931922
self.sendable_tls.write_to_async(wr).await
932923
}
933924

934-
/// Encrypt and send some plaintext `data`. `limit` controls
935-
/// whether the per-connection buffer limits apply.
936-
///
937-
/// Returns the number of bytes written from `data`: this might
938-
/// be less than `data.len()` if buffer limits were exceeded.
939-
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
940-
if !self.may_send_application_data {
941-
// If we haven't completed handshaking, buffer
942-
// plaintext to send once we do.
943-
let len = match limit {
944-
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
945-
Limit::No => self.sendable_plaintext.append(data.to_vec()),
946-
};
947-
return Ok(len);
948-
}
949-
950-
debug_assert!(self.record_layer.is_encrypting());
951-
952-
if data.is_empty() {
953-
// Don't send empty fragments.
954-
return Ok(0);
955-
}
956-
957-
self.send_appdata_encrypt(data, limit).await
958-
}
959-
960925
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
961926
self.may_send_application_data = true;
962927
self.flush_plaintext().await
@@ -1012,15 +977,14 @@ impl CommonState {
1012977
self.sendable_tls.set_limit(limit);
1013978
}
1014979

1015-
/// Send any buffered plaintext. Plaintext is buffered if
1016-
/// written during handshake.
1017-
async fn flush_plaintext(&mut self) -> Result<(), Error> {
980+
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
981+
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
1018982
if !self.may_send_application_data {
1019983
return Ok(());
1020984
}
1021985

1022986
while let Some(buf) = self.sendable_plaintext.pop() {
1023-
self.send_plain(&buf, Limit::No).await?;
987+
self.send_appdata_encrypt(&buf, Limit::No).await?;
1024988
}
1025989

1026990
Ok(())

crates/tls/client/tests/api.rs

Lines changed: 33 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ async fn servered_client_data_sent() {
247247
let (mut client, mut server) =
248248
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
249249

250-
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
250+
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
251+
client.flush_plaintext().await.unwrap();
251252

252253
do_handshake(&mut client, &mut server).await;
253254
send(&mut client, &mut server);
@@ -286,7 +287,7 @@ async fn servered_both_data_sent() {
286287
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
287288

288289
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
289-
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
290+
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
290291

291292
do_handshake(&mut client, &mut server).await;
292293

@@ -432,7 +433,7 @@ async fn server_close_notify() {
432433

433434
// check that alerts don't overtake appdata
434435
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
435-
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
436+
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
436437
server.send_close_notify();
437438

438439
receive(&mut server, &mut client);
@@ -460,7 +461,8 @@ async fn client_close_notify() {
460461

461462
// check that alerts don't overtake appdata
462463
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
463-
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
464+
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
465+
client.flush_plaintext().await.unwrap();
464466
client.send_close_notify().await.unwrap();
465467

466468
send(&mut client, &mut server);
@@ -487,7 +489,7 @@ async fn server_closes_uncleanly() {
487489

488490
// check that unclean EOF reporting does not overtake appdata
489491
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
490-
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
492+
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
491493

492494
receive(&mut server, &mut client);
493495
transfer_eof(&mut client);
@@ -518,7 +520,7 @@ async fn client_closes_uncleanly() {
518520

519521
// check that unclean EOF reporting does not overtake appdata
520522
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
521-
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
523+
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
522524
client.process_new_packets().await.unwrap();
523525

524526
send(&mut client, &mut server);
@@ -900,20 +902,9 @@ async fn client_respects_buffer_limit_pre_handshake() {
900902

901903
client.set_buffer_limit(Some(32));
902904

903-
assert_eq!(
904-
client
905-
.write_plaintext(b"01234567890123456789")
906-
.await
907-
.unwrap(),
908-
20
909-
);
910-
assert_eq!(
911-
client
912-
.write_plaintext(b"01234567890123456789")
913-
.await
914-
.unwrap(),
915-
12
916-
);
905+
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
906+
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
907+
client.flush_plaintext().await.unwrap();
917908

918909
do_handshake(&mut client, &mut server).await;
919910
send(&mut client, &mut server);
@@ -953,20 +944,9 @@ async fn client_respects_buffer_limit_post_handshake() {
953944
do_handshake(&mut client, &mut server).await;
954945
client.set_buffer_limit(Some(48));
955946

956-
assert_eq!(
957-
client
958-
.write_plaintext(b"01234567890123456789")
959-
.await
960-
.unwrap(),
961-
20
962-
);
963-
assert_eq!(
964-
client
965-
.write_plaintext(b"01234567890123456789")
966-
.await
967-
.unwrap(),
968-
6
969-
);
947+
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
948+
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
949+
client.flush_plaintext().await.unwrap();
970950

971951
send(&mut client, &mut server);
972952
server.process_new_packets().unwrap();
@@ -1211,14 +1191,8 @@ async fn client_complete_io_for_write() {
12111191

12121192
do_handshake(&mut client, &mut server).await;
12131193

1214-
client
1215-
.write_plaintext(b"01234567890123456789")
1216-
.await
1217-
.unwrap();
1218-
client
1219-
.write_plaintext(b"01234567890123456789")
1220-
.await
1221-
.unwrap();
1194+
client.write_plaintext(b"01234567890123456789").unwrap();
1195+
client.write_plaintext(b"01234567890123456789").unwrap();
12221196
{
12231197
let mut pipe = ServerSession::new(&mut server);
12241198
let (rdlen, wrlen) = client
@@ -1350,7 +1324,8 @@ async fn server_stream_read() {
13501324
for kt in ALL_KEY_TYPES.iter() {
13511325
let (mut client, mut server) = make_pair(*kt).await;
13521326

1353-
client.write_all_plaintext(b"world").await.unwrap();
1327+
client.write_all_plaintext(b"world").unwrap();
1328+
client.process_new_packets().await.unwrap();
13541329

13551330
{
13561331
let mut pipe = ClientSession::new(&mut client);
@@ -1366,7 +1341,8 @@ async fn server_streamowned_read() {
13661341
for kt in ALL_KEY_TYPES.iter() {
13671342
let (mut client, server) = make_pair(*kt).await;
13681343

1369-
client.write_all_plaintext(b"world").await.unwrap();
1344+
client.write_all_plaintext(b"world").unwrap();
1345+
client.process_new_packets().await.unwrap();
13701346

13711347
{
13721348
let pipe = ClientSession::new(&mut client);
@@ -1385,7 +1361,9 @@ async fn server_streamowned_read() {
13851361
// errkind: io::ErrorKind::ConnectionAborted,
13861362
// after: 0,
13871363
// };
1388-
// client.write_all_plaintext(b"hello").await.unwrap();
1364+
// client.write_all_plaintext(b"hello").unwrap();
1365+
// client.process_new_packets().await.unwrap();
1366+
//
13891367
// let mut client_stream = Stream::new(&mut client, &mut pipe);
13901368
// let rc = client_stream.write(b"world");
13911369
// assert!(rc.is_err());
@@ -1402,7 +1380,9 @@ async fn server_streamowned_read() {
14021380
// errkind: io::ErrorKind::ConnectionAborted,
14031381
// after: 1,
14041382
// };
1405-
// client.write_all_plaintext(b"hello").await.unwrap();
1383+
// client.write_all_plaintext(b"hello").unwrap();
1384+
// client.process_new_packets().await.unwrap();
1385+
//
14061386
// let mut client_stream = Stream::new(&mut client, &mut pipe);
14071387
// let rc = client_stream.write(b"world");
14081388
// assert_eq!(format!("{:?}", rc), "Ok(5)");
@@ -1900,14 +1880,9 @@ async fn servered_write_for_client_appdata() {
19001880
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
19011881
do_handshake(&mut client, &mut server).await;
19021882

1903-
client
1904-
.write_all_plaintext(b"01234567890123456789")
1905-
.await
1906-
.unwrap();
1907-
client
1908-
.write_all_plaintext(b"01234567890123456789")
1909-
.await
1910-
.unwrap();
1883+
client.write_all_plaintext(b"01234567890123456789").unwrap();
1884+
client.write_all_plaintext(b"01234567890123456789").unwrap();
1885+
client.process_new_packets().await.unwrap();
19111886
{
19121887
let mut pipe = ServerSession::new(&mut server);
19131888
let wrlen = client.write_tls(&mut pipe).unwrap();
@@ -2019,11 +1994,10 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
20191994
async fn servered_write_for_client_handshake() {
20201995
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
20211996

2022-
client
2023-
.write_all_plaintext(b"01234567890123456789")
2024-
.await
2025-
.unwrap();
2026-
client.write_all_plaintext(b"0123456789").await.unwrap();
1997+
client.write_all_plaintext(b"01234567890123456789").unwrap();
1998+
client.write_all_plaintext(b"0123456789").unwrap();
1999+
client.process_new_packets().await.unwrap();
2000+
20272001
{
20282002
let mut pipe = ServerSession::new(&mut server);
20292003
let wrlen = client.write_tls(&mut pipe).unwrap();

0 commit comments

Comments
 (0)