7272 fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
7373 let this = self . get_mut ( ) ;
7474
75- loop {
75+ ' start : loop {
7676 if this. state == State :: Initial {
7777 this. connect ( ) ;
7878 }
@@ -104,7 +104,9 @@ where
104104 . expect( "Stream state should be unreachable without stream" )
105105 . poll_next_unpin( cx) )
106106 else {
107- return Poll :: Ready ( None ) ;
107+ this. state = State :: Initial ;
108+
109+ continue ' start;
108110 } ;
109111
110112 match msg {
@@ -241,6 +243,14 @@ mod tests {
241243 #[ derive( Default ) ]
242244 struct FakeStream ( Vec < Result < Message , Error > > ) ;
243245
246+ impl FakeStream {
247+ fn new ( mut messages : Vec < Result < Message , Error > > ) -> Self {
248+ messages. reverse ( ) ;
249+
250+ Self ( messages)
251+ }
252+ }
253+
244254 impl Clone for FakeStream {
245255 fn clone ( & self ) -> Self {
246256 Self (
@@ -350,7 +360,7 @@ mod tests {
350360
351361 impl < T : IntoIterator < Item = Result < Message , Error > > > From < T > for FakeConnector {
352362 fn from ( value : T ) -> Self {
353- Self ( FakeStream ( value. into_iter ( ) . collect ( ) ) )
363+ Self ( FakeStream :: new ( value. into_iter ( ) . collect ( ) ) )
354364 }
355365 }
356366
@@ -368,7 +378,7 @@ mod tests {
368378
369379 impl < T : IntoIterator < Item = Result < Message , Error > > > From < T > for FakeConnectorWithSink {
370380 fn from ( value : T ) -> Self {
371- Self ( FakeStream ( value. into_iter ( ) . collect ( ) ) )
381+ Self ( FakeStream :: new ( value. into_iter ( ) . collect ( ) ) )
372382 }
373383 }
374384
@@ -404,13 +414,8 @@ mod tests {
404414 Ok ( Message :: Binary ( Bytes :: from ( compressed) ) )
405415 }
406416
407- #[ test_case:: test_case( to_json_message; "json" ) ]
408- #[ test_case:: test_case( to_brotli_message; "brotli" ) ]
409- #[ tokio:: test]
410- async fn test_stream_decodes_messages_successfully (
411- to_message : impl Fn ( & FlashBlock ) -> Result < Message , Error > ,
412- ) {
413- let flashblocks = [ FlashBlock {
417+ fn flashblock ( ) -> FlashBlock {
418+ FlashBlock {
414419 payload_id : Default :: default ( ) ,
415420 index : 0 ,
416421 base : Some ( ExecutionPayloadBaseV1 {
@@ -426,34 +431,51 @@ mod tests {
426431 } ) ,
427432 diff : Default :: default ( ) ,
428433 metadata : Default :: default ( ) ,
429- } ] ;
434+ }
435+ }
430436
431- let messages = FakeConnector :: from ( flashblocks. iter ( ) . map ( to_message) ) ;
437+ #[ test_case:: test_case( to_json_message; "json" ) ]
438+ #[ test_case:: test_case( to_brotli_message; "brotli" ) ]
439+ #[ tokio:: test]
440+ async fn test_stream_decodes_messages_successfully (
441+ to_message : impl Fn ( & FlashBlock ) -> Result < Message , Error > ,
442+ ) {
443+ let flashblocks = [ flashblock ( ) ] ;
444+ let connector = FakeConnector :: from ( flashblocks. iter ( ) . map ( to_message) ) ;
432445 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
433- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
446+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
434447
435- let actual_messages: Vec < _ > = stream. map ( Result :: unwrap) . collect ( ) . await ;
448+ let actual_messages: Vec < _ > = stream. take ( 1 ) . map ( Result :: unwrap) . collect ( ) . await ;
436449 let expected_messages = flashblocks. to_vec ( ) ;
437450
438451 assert_eq ! ( actual_messages, expected_messages) ;
439452 }
440453
441454 #[ tokio:: test]
442455 async fn test_stream_ignores_non_binary_message ( ) {
443- let messages = FakeConnector :: from ( [ Ok ( Message :: Text ( Utf8Bytes :: from ( "test" ) ) ) ] ) ;
456+ let flashblock = flashblock ( ) ;
457+ let connector = FakeConnector :: from ( [
458+ Ok ( Message :: Text ( Utf8Bytes :: from ( "test" ) ) ) ,
459+ to_json_message ( & flashblock) ,
460+ ] ) ;
444461 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
445- let mut stream = WsFlashBlockStream :: with_connector ( ws_url, messages) ;
446- assert ! ( stream. next( ) . await . is_none( ) ) ;
462+ let mut stream = WsFlashBlockStream :: with_connector ( ws_url, connector) ;
463+
464+ let expected_message = flashblock;
465+ let actual_message =
466+ stream. next ( ) . await . expect ( "Binary message should not be ignored" ) . unwrap ( ) ;
467+
468+ assert_eq ! ( actual_message, expected_message)
447469 }
448470
449471 #[ tokio:: test]
450472 async fn test_stream_passes_errors_through ( ) {
451- let messages = FakeConnector :: from ( [ Err ( Error :: AttackAttempt ) ] ) ;
473+ let connector = FakeConnector :: from ( [ Err ( Error :: AttackAttempt ) ] ) ;
452474 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
453- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
475+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
454476
455477 let actual_messages: Vec < _ > =
456- stream. map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
478+ stream. take ( 1 ) . map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
457479 let expected_messages = vec ! [ "Attack attempt detected" . to_owned( ) ] ;
458480
459481 assert_eq ! ( actual_messages, expected_messages) ;
@@ -463,9 +485,9 @@ mod tests {
463485 async fn test_connect_error_causes_retries ( ) {
464486 let tries = 3 ;
465487 let error_msg = "test" . to_owned ( ) ;
466- let messages = FailingConnector ( error_msg. clone ( ) ) ;
488+ let connector = FailingConnector ( error_msg. clone ( ) ) ;
467489 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
468- let stream = WsFlashBlockStream :: with_connector ( ws_url, messages ) ;
490+ let stream = WsFlashBlockStream :: with_connector ( ws_url, connector ) ;
469491
470492 let actual_errors: Vec < _ > =
471493 stream. take ( tries) . map ( Result :: unwrap_err) . map ( |e| format ! ( "{e}" ) ) . collect ( ) . await ;
@@ -478,7 +500,8 @@ mod tests {
478500 async fn test_stream_pongs_ping ( ) {
479501 const ECHO : [ u8 ; 3 ] = [ 1u8 , 2 , 3 ] ;
480502
481- let messages = [ Ok ( Message :: Ping ( Bytes :: from_static ( & ECHO ) ) ) ] ;
503+ let flashblock = flashblock ( ) ;
504+ let messages = [ Ok ( Message :: Ping ( Bytes :: from_static ( & ECHO ) ) ) , to_json_message ( & flashblock) ] ;
482505 let connector = FakeConnectorWithSink :: from ( messages) ;
483506 let ws_url = "http://localhost" . parse ( ) . unwrap ( ) ;
484507 let mut stream = WsFlashBlockStream :: with_connector ( ws_url, connector) ;
0 commit comments