7272 fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
7373 let this = self . get_mut ( ) ;
7474
75- loop {
75+ ' start : loop {
7676 if this. state == State :: Initial {
7777 this. connect ( ) ;
7878 }
@@ -104,7 +104,9 @@ where
104104 . expect( "Stream state should be unreachable without stream" )
105105 . poll_next_unpin( cx) )
106106 else {
107- return Poll :: Ready ( None ) ;
107+ this. state = State :: Initial ;
108+
109+ continue ' start;
108110 } ;
109111
110112 match msg {
@@ -251,6 +253,14 @@ mod tests {
251253 #[ derive( Default ) ]
252254 struct FakeStream ( Vec < Result < Message , Error > > ) ;
253255
256+ impl FakeStream {
257+ fn new ( mut messages : Vec < Result < Message , Error > > ) -> Self {
258+ messages. reverse ( ) ;
259+
260+ Self ( messages)
261+ }
262+ }
263+
254264 impl Clone for FakeStream {
255265 fn clone ( & self ) -> Self {
256266 Self (
@@ -360,7 +370,7 @@ mod tests {
360370
361371 impl < T : IntoIterator < Item = Result < Message , Error > > > From < T > for FakeConnector {
362372 fn from ( value : T ) -> Self {
363- Self ( FakeStream ( value. into_iter ( ) . collect ( ) ) )
373+ Self ( FakeStream :: new ( value. into_iter ( ) . collect ( ) ) )
364374 }
365375 }
366376
@@ -378,7 +388,7 @@ mod tests {
378388
379389 impl < T : IntoIterator < Item = Result < Message , Error > > > From < T > for FakeConnectorWithSink {
380390 fn from ( value : T ) -> Self {
381- Self ( FakeStream ( value. into_iter ( ) . collect ( ) ) )
391+ Self ( FakeStream :: new ( value. into_iter ( ) . collect ( ) ) )
382392 }
383393 }
384394
@@ -414,13 +424,8 @@ mod tests {
414424 Ok ( Message :: Binary ( Bytes :: from ( compressed) ) )
415425 }
416426
417- #[ test_case:: test_case( to_json_message; "json" ) ]
418- #[ test_case:: test_case( to_brotli_message; "brotli" ) ]
419- #[ tokio:: test]
420- async fn test_stream_decodes_messages_successfully (
421- to_message : impl Fn ( & FlashBlock ) -> Result < Message , Error > ,
422- ) {
423- let flashblocks = [ FlashBlock {
427+ fn flashblock ( ) -> FlashBlock {
428+ FlashBlock {
424429 payload_id : Default :: default ( ) ,
425430 index : 0 ,
426431 base : Some ( ExecutionPayloadBaseV1 {
@@ -436,13 +441,21 @@ mod tests {
436441 } ) ,
437442 diff : Default :: default ( ) ,
438443 metadata : Default :: default ( ) ,
439- } ] ;
444+ }
445+ }
440446
441- let messages = FakeConnector :: from ( flashblocks. iter ( ) . map ( to_message) ) ;
447+ #[ test_case:: test_case( to_json_message; "json" ) ]
448+ #[ test_case:: test_case( to_brotli_message; "brotli" ) ]
449+ #[ tokio:: test]
450+ async fn test_stream_decodes_messages_successfully (
451+ to_message : impl Fn ( & FlashBlock ) -> Result < Message , Error > ,
452+ ) {
453+ let flashblocks = [ flashblock ( ) ] ;
454+ let connector = FakeConnector :: from ( flashblocks. iter ( ) . map ( to_message) ) ;
442455 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
443- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
456+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
444457
445- let actual_messages: Vec < _ > = stream. map ( Result :: unwrap) . collect ( ) . await ;
458+ let actual_messages: Vec < _ > = stream. take ( 1 ) . map ( Result :: unwrap) . collect ( ) . await ;
446459 let expected_messages = flashblocks. to_vec ( ) ;
447460
448461 assert_eq ! ( actual_messages, expected_messages) ;
@@ -452,20 +465,26 @@ mod tests {
452465 #[ test_case:: test_case( Message :: Frame ( Frame :: pong( b"test" . as_slice( ) ) ) ; "frame" ) ]
453466 #[ tokio:: test]
454467 async fn test_stream_ignores_unexpected_message ( message : Message ) {
455- let messages = FakeConnector :: from ( [ Ok ( message) ] ) ;
468+ let flashblock = flashblock ( ) ;
469+ let connector = FakeConnector :: from ( [ Ok ( message) , to_json_message ( & flashblock) ] ) ;
456470 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
457- let mut stream = WsFlashBlockStream :: with_connector ( ws_url, messages) ;
458- assert ! ( stream. next( ) . await . is_none( ) ) ;
471+ let mut stream = WsFlashBlockStream :: with_connector ( ws_url, connector) ;
472+
473+ let expected_message = flashblock;
474+ let actual_message =
475+ stream. next ( ) . await . expect ( "Binary message should not be ignored" ) . unwrap ( ) ;
476+
477+ assert_eq ! ( actual_message, expected_message)
459478 }
460479
461480 #[ tokio:: test]
462481 async fn test_stream_passes_errors_through ( ) {
463- let messages = FakeConnector :: from ( [ Err ( Error :: AttackAttempt ) ] ) ;
482+ let connector = FakeConnector :: from ( [ Err ( Error :: AttackAttempt ) ] ) ;
464483 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
465- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
484+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
466485
467486 let actual_messages: Vec < _ > =
468- stream. map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
487+ stream. take ( 1 ) . map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
469488 let expected_messages = vec ! [ "Attack attempt detected" . to_owned( ) ] ;
470489
471490 assert_eq ! ( actual_messages, expected_messages) ;
@@ -475,9 +494,9 @@ mod tests {
475494 async fn test_connect_error_causes_retries ( ) {
476495 let tries = 3 ;
477496 let error_msg = "test" . to_owned ( ) ;
478- let messages = FailingConnector ( error_msg. clone ( ) ) ;
497+ let connector = FailingConnector ( error_msg. clone ( ) ) ;
479498 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
480- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
499+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
481500
482501 let actual_errors: Vec < _ > =
483502 stream. take ( tries) . map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
@@ -490,7 +509,8 @@ mod tests {
490509 async fn test_stream_pongs_ping ( ) {
491510 const ECHO : [ u8 ; 3 ] = [ 1u8 , 2 , 3 ] ;
492511
493- let messages = [ Ok ( Message :: Ping ( Bytes :: from_static ( & ECHO ) ) ) ] ;
512+ let flashblock = flashblock ( ) ;
513+ let messages = [ Ok ( Message :: Ping ( Bytes :: from_static ( & ECHO ) ) ) , to_json_message ( & flashblock) ] ;
494514 let connector = FakeConnectorWithSink :: from ( messages) ;
495515 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
496516 let mut stream = WsFlashBlockStream :: with_connector ( ws_url, connector) ;
0 commit comments