Skip to content

Commit 78e5e97

Browse files
committed
fix(optimism): Reconnect if ws stream ends in WsFlashBlockStream
1 parent b1e1932 commit 78e5e97

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

crates/optimism/flashblocks/src/ws/stream.rs

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ where
7272
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
7373
let this = self.get_mut();
7474

75-
loop {
75+
'start: loop {
7676
if this.state == State::Initial {
7777
this.connect();
7878
}
@@ -104,7 +104,9 @@ where
104104
.expect("Stream state should be unreachable without stream")
105105
.poll_next_unpin(cx))
106106
else {
107-
return Poll::Ready(None);
107+
this.state = State::Initial;
108+
109+
continue 'start;
108110
};
109111

110112
match msg {
@@ -244,6 +246,14 @@ mod tests {
244246
#[derive(Default)]
245247
struct FakeStream(Vec<Result<Message, Error>>);
246248

249+
impl FakeStream {
250+
fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
251+
messages.reverse();
252+
253+
Self(messages)
254+
}
255+
}
256+
247257
impl Clone for FakeStream {
248258
fn clone(&self) -> Self {
249259
Self(
@@ -353,7 +363,7 @@ mod tests {
353363

354364
impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
355365
fn from(value: T) -> Self {
356-
Self(FakeStream(value.into_iter().collect()))
366+
Self(FakeStream::new(value.into_iter().collect()))
357367
}
358368
}
359369

@@ -371,7 +381,7 @@ mod tests {
371381

372382
impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
373383
fn from(value: T) -> Self {
374-
Self(FakeStream(value.into_iter().collect()))
384+
Self(FakeStream::new(value.into_iter().collect()))
375385
}
376386
}
377387

@@ -407,13 +417,8 @@ mod tests {
407417
Ok(Message::Binary(Bytes::from(compressed)))
408418
}
409419

410-
#[test_case::test_case(to_json_message; "json")]
411-
#[test_case::test_case(to_brotli_message; "brotli")]
412-
#[tokio::test]
413-
async fn test_stream_decodes_messages_successfully(
414-
to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
415-
) {
416-
let flashblocks = [FlashBlock {
420+
fn flashblock() -> FlashBlock {
421+
FlashBlock {
417422
payload_id: Default::default(),
418423
index: 0,
419424
base: Some(ExecutionPayloadBaseV1 {
@@ -429,13 +434,21 @@ mod tests {
429434
}),
430435
diff: Default::default(),
431436
metadata: Default::default(),
432-
}];
437+
}
438+
}
433439

434-
let messages = FakeConnector::from(flashblocks.iter().map(to_message));
440+
#[test_case::test_case(to_json_message; "json")]
441+
#[test_case::test_case(to_brotli_message; "brotli")]
442+
#[tokio::test]
443+
async fn test_stream_decodes_messages_successfully(
444+
to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
445+
) {
446+
let flashblocks = [flashblock()];
447+
let connector = FakeConnector::from(flashblocks.iter().map(to_message));
435448
let ws_url = "http://localhost".parse().unwrap();
436-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
449+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
437450

438-
let actual_messages: Vec<_> = stream.map(Result::unwrap).collect().await;
451+
let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
439452
let expected_messages = flashblocks.to_vec();
440453

441454
assert_eq!(actual_messages, expected_messages);
@@ -445,20 +458,26 @@ mod tests {
445458
#[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
446459
#[tokio::test]
447460
async fn test_stream_ignores_unexpected_message(message: Message) {
448-
let messages = FakeConnector::from([Ok(message)]);
461+
let flashblock = flashblock();
462+
let connector = FakeConnector::from([Ok(message), to_json_message(&flashblock)]);
449463
let ws_url = "http://localhost".parse().unwrap();
450-
let mut stream = WsFlashBlockStream::with_connector(ws_url, messages);
451-
assert!(stream.next().await.is_none());
464+
let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);
465+
466+
let expected_message = flashblock;
467+
let actual_message =
468+
stream.next().await.expect("Binary message should not be ignored").unwrap();
469+
470+
assert_eq!(actual_message, expected_message)
452471
}
453472

454473
#[tokio::test]
455474
async fn test_stream_passes_errors_through() {
456-
let messages = FakeConnector::from([Err(Error::AttackAttempt)]);
475+
let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
457476
let ws_url = "http://localhost".parse().unwrap();
458-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
477+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
459478

460479
let actual_messages: Vec<_> =
461-
stream.map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
480+
stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
462481
let expected_messages = vec!["Attack attempt detected".to_owned()];
463482

464483
assert_eq!(actual_messages, expected_messages);
@@ -468,9 +487,9 @@ mod tests {
468487
async fn test_connect_error_causes_retries() {
469488
let tries = 3;
470489
let error_msg = "test".to_owned();
471-
let messages = FailingConnector(error_msg.clone());
490+
let connector = FailingConnector(error_msg.clone());
472491
let ws_url = "http://localhost".parse().unwrap();
473-
let stream = WsFlashBlockStream::with_connector(ws_url, messages);
492+
let stream = WsFlashBlockStream::with_connector(ws_url, connector);
474493

475494
let actual_errors: Vec<_> =
476495
stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
@@ -483,7 +502,8 @@ mod tests {
483502
async fn test_stream_pongs_ping() {
484503
const ECHO: [u8; 3] = [1u8, 2, 3];
485504

486-
let messages = [Ok(Message::Ping(Bytes::from_static(&ECHO)))];
505+
let flashblock = flashblock();
506+
let messages = [Ok(Message::Ping(Bytes::from_static(&ECHO))), to_json_message(&flashblock)];
487507
let connector = FakeConnectorWithSink::from(messages);
488508
let ws_url = "http://localhost".parse().unwrap();
489509
let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

0 commit comments

Comments
 (0)