Skip to content

Commit e583aea

Browse files
committed
fix(optimism): Reconnect if ws stream ends in WsFlashBlockStream
1 parent 36e39eb commit e583aea

File tree

1 file changed

+47
-24
lines changed

1 file changed

+47
-24
lines changed

crates/optimism/flashblocks/src/ws/stream.rs

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ where
7272
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
7373
let this = self.get_mut();
7474

75-
loop {
75+
'start: loop {
7676
if this.state == State::Initial {
7777
this.connect();
7878
}
@@ -104,7 +104,9 @@ where
104104
.expect("Stream state should be unreachable without stream")
105105
.poll_next_unpin(cx))
106106
else {
107-
return Poll::Ready(None);
107+
this.state = State::Initial;
108+
109+
continue 'start;
108110
};
109111

110112
match msg {
@@ -241,6 +243,14 @@ mod tests {
241243
#[derive(Default)]
242244
struct FakeStream(Vec<Result<Message, Error>>);
243245

246+
impl FakeStream {
247+
fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
248+
messages.reverse();
249+
250+
Self(messages)
251+
}
252+
}
253+
244254
impl Clone for FakeStream {
245255
fn clone(&self) -> Self {
246256
Self(
@@ -350,7 +360,7 @@ mod tests {
350360

351361
impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
352362
fn from(value: T) -> Self {
353-
Self(FakeStream(value.into_iter().collect()))
363+
Self(FakeStream::new(value.into_iter().collect()))
354364
}
355365
}
356366

@@ -368,7 +378,7 @@ mod tests {
368378

369379
impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
370380
fn from(value: T) -> Self {
371-
Self(FakeStream(value.into_iter().collect()))
381+
Self(FakeStream::new(value.into_iter().collect()))
372382
}
373383
}
374384

@@ -404,13 +414,8 @@ mod tests {
404414
Ok(Message::Binary(Bytes::from(compressed)))
405415
}
406416

407-
#[test_case::test_case(to_json_message; "json")]
408-
#[test_case::test_case(to_brotli_message; "brotli")]
409-
#[tokio::test]
410-
async fn test_stream_decodes_messages_successfully(
411-
to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
412-
) {
413-
let flashblocks = [FlashBlock {
417+
fn flashblock() -> FlashBlock {
418+
FlashBlock {
414419
payload_id: Default::default(),
415420
index: 0,
416421
base: Some(ExecutionPayloadBaseV1 {
@@ -426,34 +431,51 @@ mod tests {
426431
}),
427432
diff: Default::default(),
428433
metadata: Default::default(),
429-
}];
434+
}
435+
}
430436

431-
let messages = FakeConnector::from(flashblocks.iter().map(to_message));
437+
#[test_case::test_case(to_json_message; "json")]
438+
#[test_case::test_case(to_brotli_message; "brotli")]
439+
#[tokio::test]
440+
async fn test_stream_decodes_messages_successfully(
441+
to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
442+
) {
443+
let flashblocks = [flashblock()];
444+
let connector = FakeConnector::from(flashblocks.iter().map(to_message));
432445
let ws_url = "http://localhost".parse().unwrap();
433-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
446+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
434447

435-
let actual_messages: Vec<_> = stream.map(Result::unwrap).collect().await;
448+
let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
436449
let expected_messages = flashblocks.to_vec();
437450

438451
assert_eq!(actual_messages, expected_messages);
439452
}
440453

441454
#[tokio::test]
442455
async fn test_stream_ignores_non_binary_message() {
443-
let messages = FakeConnector::from([Ok(Message::Text(Utf8Bytes::from("test")))]);
456+
let flashblock = flashblock();
457+
let connector = FakeConnector::from([
458+
Ok(Message::Text(Utf8Bytes::from("test"))),
459+
to_json_message(&flashblock),
460+
]);
444461
let ws_url = "http://localhost".parse().unwrap();
445-
let mut stream = WsFlashBlockStream::with_connector(ws_url, messages);
446-
assert!(stream.next().await.is_none());
462+
let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);
463+
464+
let expected_message = flashblock;
465+
let actual_message =
466+
stream.next().await.expect("Binary message should not be ignored").unwrap();
467+
468+
assert_eq!(actual_message, expected_message)
447469
}
448470

449471
#[tokio::test]
450472
async fn test_stream_passes_errors_through() {
451-
let messages = FakeConnector::from([Err(Error::AttackAttempt)]);
473+
let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
452474
let ws_url = "http://localhost".parse().unwrap();
453-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
475+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
454476

455477
let actual_messages: Vec<_> =
456-
stream.map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
478+
stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
457479
let expected_messages = vec!["Attack attempt detected".to_owned()];
458480

459481
assert_eq!(actual_messages, expected_messages);
@@ -463,9 +485,9 @@ mod tests {
463485
async fn test_connect_error_causes_retries() {
464486
let tries = 3;
465487
let error_msg = "test".to_owned();
466-
let messages = FailingConnector(error_msg.clone());
488+
let connector = FailingConnector(error_msg.clone());
467489
let ws_url = "http://localhost".parse().unwrap();
468-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
490+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
469491

470492
let actual_errors: Vec<_> =
471493
stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
@@ -478,7 +500,8 @@ mod tests {
478500
async fn test_stream_pongs_ping() {
479501
const ECHO: [u8; 3] = [1u8, 2, 3];
480502

481-
let messages = [Ok(Message::Ping(Bytes::from_static(&ECHO)))];
503+
let flashblock = flashblock();
504+
let messages = [Ok(Message::Ping(Bytes::from_static(&ECHO))), to_json_message(&flashblock)];
482505
let connector = FakeConnectorWithSink::from(messages);
483506
let ws_url = "http://localhost".parse().unwrap();
484507
let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

0 commit comments

Comments
 (0)