7272 fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
7373 let this = self . get_mut ( ) ;
7474
75- loop {
75+ ' start : loop {
7676 if this. state == State :: Initial {
7777 this. connect ( ) ;
7878 }
@@ -104,7 +104,9 @@ where
104104 . expect( "Stream state should be unreachable without stream" )
105105 . poll_next_unpin( cx) )
106106 else {
107- return Poll :: Ready ( None ) ;
107+ this. state = State :: Initial ;
108+
109+ continue ' start;
108110 } ;
109111
110112 match msg {
@@ -244,6 +246,14 @@ mod tests {
244246 #[ derive( Default ) ]
245247 struct FakeStream ( Vec < Result < Message , Error > > ) ;
246248
249+ impl FakeStream {
250+ fn new ( mut messages : Vec < Result < Message , Error > > ) -> Self {
251+ messages. reverse ( ) ;
252+
253+ Self ( messages)
254+ }
255+ }
256+
247257 impl Clone for FakeStream {
248258 fn clone ( & self ) -> Self {
249259 Self (
@@ -353,7 +363,7 @@ mod tests {
353363
354364 impl < T : IntoIterator < Item = Result < Message , Error > > > From < T > for FakeConnector {
355365 fn from ( value : T ) -> Self {
356- Self ( FakeStream ( value. into_iter ( ) . collect ( ) ) )
366+ Self ( FakeStream :: new ( value. into_iter ( ) . collect ( ) ) )
357367 }
358368 }
359369
@@ -371,7 +381,7 @@ mod tests {
371381
372382 impl < T : IntoIterator < Item = Result < Message , Error > > > From < T > for FakeConnectorWithSink {
373383 fn from ( value : T ) -> Self {
374- Self ( FakeStream ( value. into_iter ( ) . collect ( ) ) )
384+ Self ( FakeStream :: new ( value. into_iter ( ) . collect ( ) ) )
375385 }
376386 }
377387
@@ -407,13 +417,8 @@ mod tests {
407417 Ok ( Message :: Binary ( Bytes :: from ( compressed) ) )
408418 }
409419
410- #[ test_case:: test_case( to_json_message; "json" ) ]
411- #[ test_case:: test_case( to_brotli_message; "brotli" ) ]
412- #[ tokio:: test]
413- async fn test_stream_decodes_messages_successfully (
414- to_message : impl Fn ( & FlashBlock ) -> Result < Message , Error > ,
415- ) {
416- let flashblocks = [ FlashBlock {
420+ fn flashblock ( ) -> FlashBlock {
421+ FlashBlock {
417422 payload_id : Default :: default ( ) ,
418423 index : 0 ,
419424 base : Some ( ExecutionPayloadBaseV1 {
@@ -429,13 +434,21 @@ mod tests {
429434 } ) ,
430435 diff : Default :: default ( ) ,
431436 metadata : Default :: default ( ) ,
432- } ] ;
437+ }
438+ }
433439
434- let messages = FakeConnector :: from ( flashblocks. iter ( ) . map ( to_message) ) ;
440+ #[ test_case:: test_case( to_json_message; "json" ) ]
441+ #[ test_case:: test_case( to_brotli_message; "brotli" ) ]
442+ #[ tokio:: test]
443+ async fn test_stream_decodes_messages_successfully (
444+ to_message : impl Fn ( & FlashBlock ) -> Result < Message , Error > ,
445+ ) {
446+ let flashblocks = [ flashblock ( ) ] ;
447+ let connector = FakeConnector :: from ( flashblocks. iter ( ) . map ( to_message) ) ;
435448 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
436- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
449+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
437450
438- let actual_messages: Vec < _ > = stream. map ( Result :: unwrap) . collect ( ) . await ;
451+ let actual_messages: Vec < _ > = stream. take ( 1 ) . map ( Result :: unwrap) . collect ( ) . await ;
439452 let expected_messages = flashblocks. to_vec ( ) ;
440453
441454 assert_eq ! ( actual_messages, expected_messages) ;
@@ -445,20 +458,26 @@ mod tests {
445458 #[ test_case:: test_case( Message :: Frame ( Frame :: pong( b"test" . as_slice( ) ) ) ; "frame" ) ]
446459 #[ tokio:: test]
447460 async fn test_stream_ignores_unexpected_message ( message : Message ) {
448- let messages = FakeConnector :: from ( [ Ok ( message) ] ) ;
461+ let flashblock = flashblock ( ) ;
462+ let connector = FakeConnector :: from ( [ Ok ( message) , to_json_message ( & flashblock) ] ) ;
449463 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
450- let mut stream = WsFlashBlockStream :: with_connector ( ws_url, messages) ;
451- assert ! ( stream. next( ) . await . is_none( ) ) ;
464+ let mut stream = WsFlashBlockStream :: with_connector ( ws_url, connector) ;
465+
466+ let expected_message = flashblock;
467+ let actual_message =
468+ stream. next ( ) . await . expect ( "Binary message should not be ignored" ) . unwrap ( ) ;
469+
470+ assert_eq ! ( actual_message, expected_message)
452471 }
453472
454473 #[ tokio:: test]
455474 async fn test_stream_passes_errors_through ( ) {
456- let messages = FakeConnector :: from ( [ Err ( Error :: AttackAttempt ) ] ) ;
475+ let connector = FakeConnector :: from ( [ Err ( Error :: AttackAttempt ) ] ) ;
457476 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
458- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
477+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
459478
460479 let actual_messages: Vec < _ > =
461- stream. map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
480+ stream. take ( 1 ) . map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
462481 let expected_messages = vec ! [ "Attack attempt detected" . to_owned( ) ] ;
463482
464483 assert_eq ! ( actual_messages, expected_messages) ;
@@ -468,9 +487,9 @@ mod tests {
468487 async fn test_connect_error_causes_retries ( ) {
469488 let tries = 3 ;
470489 let error_msg = "test" . to_owned ( ) ;
471- let messages = FailingConnector ( error_msg. clone ( ) ) ;
490+ let connector = FailingConnector ( error_msg. clone ( ) ) ;
472491 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
473- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
492+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
474493
475494 let actual_errors: Vec < _ > =
476495 stream. take ( tries) . map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
@@ -483,7 +502,8 @@ mod tests {
483502 async fn test_stream_pongs_ping ( ) {
484503 const ECHO : [ u8 ; 3 ] = [ 1u8 , 2 , 3 ] ;
485504
486- let messages = [ Ok ( Message :: Ping ( Bytes :: from_static ( & ECHO ) ) ) ] ;
505+ let flashblock = flashblock ( ) ;
506+ let messages = [ Ok ( Message :: Ping ( Bytes :: from_static ( & ECHO ) ) ) , to_json_message ( & flashblock) ] ;
487507 let connector = FakeConnectorWithSink :: from ( messages) ;
488508 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
489509 let mut stream = WsFlashBlockStream :: with_connector ( ws_url, connector) ;
0 commit comments