Skip to content

Commit c57feda

Browse files
authored
fix(optimism): Reconnect if ws stream ends in WsFlashBlockStream (#18226)
1 parent ecd1898 commit c57feda

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

crates/optimism/flashblocks/src/ws/stream.rs

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ where
7272
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
7373
let this = self.get_mut();
7474

75-
loop {
75+
'start: loop {
7676
if this.state == State::Initial {
7777
this.connect();
7878
}
@@ -104,7 +104,9 @@ where
104104
.expect("Stream state should be unreachable without stream")
105105
.poll_next_unpin(cx))
106106
else {
107-
return Poll::Ready(None);
107+
this.state = State::Initial;
108+
109+
continue 'start;
108110
};
109111

110112
match msg {
@@ -251,6 +253,14 @@ mod tests {
251253
#[derive(Default)]
252254
struct FakeStream(Vec<Result<Message, Error>>);
253255

256+
impl FakeStream {
257+
fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
258+
messages.reverse();
259+
260+
Self(messages)
261+
}
262+
}
263+
254264
impl Clone for FakeStream {
255265
fn clone(&self) -> Self {
256266
Self(
@@ -360,7 +370,7 @@ mod tests {
360370

361371
impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
362372
fn from(value: T) -> Self {
363-
Self(FakeStream(value.into_iter().collect()))
373+
Self(FakeStream::new(value.into_iter().collect()))
364374
}
365375
}
366376

@@ -378,7 +388,7 @@ mod tests {
378388

379389
impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
380390
fn from(value: T) -> Self {
381-
Self(FakeStream(value.into_iter().collect()))
391+
Self(FakeStream::new(value.into_iter().collect()))
382392
}
383393
}
384394

@@ -414,13 +424,8 @@ mod tests {
414424
Ok(Message::Binary(Bytes::from(compressed)))
415425
}
416426

417-
#[test_case::test_case(to_json_message; "json")]
418-
#[test_case::test_case(to_brotli_message; "brotli")]
419-
#[tokio::test]
420-
async fn test_stream_decodes_messages_successfully(
421-
to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
422-
) {
423-
let flashblocks = [FlashBlock {
427+
fn flashblock() -> FlashBlock {
428+
FlashBlock {
424429
payload_id: Default::default(),
425430
index: 0,
426431
base: Some(ExecutionPayloadBaseV1 {
@@ -436,13 +441,21 @@ mod tests {
436441
}),
437442
diff: Default::default(),
438443
metadata: Default::default(),
439-
}];
444+
}
445+
}
440446

441-
let messages = FakeConnector::from(flashblocks.iter().map(to_message));
447+
#[test_case::test_case(to_json_message; "json")]
448+
#[test_case::test_case(to_brotli_message; "brotli")]
449+
#[tokio::test]
450+
async fn test_stream_decodes_messages_successfully(
451+
to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
452+
) {
453+
let flashblocks = [flashblock()];
454+
let connector = FakeConnector::from(flashblocks.iter().map(to_message));
442455
let ws_url = "http://localhost".parse().unwrap();
443-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
456+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
444457

445-
let actual_messages: Vec<_> = stream.map(Result::unwrap).collect().await;
458+
let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
446459
let expected_messages = flashblocks.to_vec();
447460

448461
assert_eq!(actual_messages, expected_messages);
@@ -452,20 +465,26 @@ mod tests {
452465
#[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
453466
#[tokio::test]
454467
async fn test_stream_ignores_unexpected_message(message: Message) {
455-
let messages = FakeConnector::from([Ok(message)]);
468+
let flashblock = flashblock();
469+
let connector = FakeConnector::from([Ok(message), to_json_message(&flashblock)]);
456470
let ws_url = "http://localhost".parse().unwrap();
457-
let mut stream = WsFlashBlockStream::with_connector(ws_url, messages);
458-
assert!(stream.next().await.is_none());
471+
let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);
472+
473+
let expected_message = flashblock;
474+
let actual_message =
475+
stream.next().await.expect("Binary message should not be ignored").unwrap();
476+
477+
assert_eq!(actual_message, expected_message)
459478
}
460479

461480
#[tokio::test]
462481
async fn test_stream_passes_errors_through() {
463-
let messages = FakeConnector::from([Err(Error::AttackAttempt)]);
482+
let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
464483
let ws_url = "http://localhost".parse().unwrap();
465-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
484+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
466485

467486
let actual_messages: Vec<_> =
468-
stream.map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
487+
stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
469488
let expected_messages = vec!["Attack attempt detected".to_owned()];
470489

471490
assert_eq!(actual_messages, expected_messages);
@@ -475,9 +494,9 @@ mod tests {
475494
async fn test_connect_error_causes_retries() {
476495
let tries = 3;
477496
let error_msg = "test".to_owned();
478-
let messages = FailingConnector(error_msg.clone());
497+
let connector = FailingConnector(error_msg.clone());
479498
let ws_url = "http://localhost".parse().unwrap();
480-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
499+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
481500

482501
let actual_errors: Vec<_> =
483502
stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
@@ -490,7 +509,8 @@ mod tests {
490509
async fn test_stream_pongs_ping() {
491510
const ECHO: [u8; 3] = [1u8, 2, 3];
492511

493-
let messages = [Ok(Message::Ping(Bytes::from_static(&ECHO)))];
512+
let flashblock = flashblock();
513+
let messages = [Ok(Message::Ping(Bytes::from_static(&ECHO))), to_json_message(&flashblock)];
494514
let connector = FakeConnectorWithSink::from(messages);
495515
let ws_url = "http://localhost".parse().unwrap();
496516
let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

0 commit comments

Comments
 (0)