@@ -17,8 +17,8 @@ use crate::{
17
17
18
18
use dynamo_runtime:: {
19
19
pipeline:: {
20
- AsyncEngineContextProvider , Context , ManyOut , Operator , ResponseStream ,
21
- ServerStreamingEngine , SingleIn , async_trait,
20
+ AsyncEngineContext , AsyncEngineContextProvider , Context , ManyOut , Operator , ResponseStream ,
21
+ ServerStreamingEngine , SingleIn , async_trait, network :: STREAM_ERR_MSG ,
22
22
} ,
23
23
protocols:: { annotated:: Annotated , maybe_error:: MaybeError } ,
24
24
} ;
55
55
next : ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > ,
56
56
) -> Result < ManyOut < Annotated < LLMEngineOutput > > > {
57
57
let ( preprocessed_request, context) = request. transfer ( ( ) ) ;
58
- let context_id = context. id ( ) . to_string ( ) ;
59
58
let engine_ctx = context. context ( ) ;
60
59
let engine_ctx_ = engine_ctx. clone ( ) ;
61
60
let retry_manager =
62
- RetryManager :: build ( context_id , preprocessed_request, next, self . migration_limit )
61
+ RetryManager :: build ( engine_ctx , preprocessed_request, next, self . migration_limit )
63
62
. await ?;
64
- let response_stream = stream:: unfold ( retry_manager, move |mut retry_manager| {
65
- let engine_ctx = engine_ctx_. clone ( ) ;
66
- async move {
67
- if engine_ctx. is_stopped ( ) || engine_ctx. is_killed ( ) {
68
- return None ; // Stop if the context is cancelled or stopped
69
- }
70
- retry_manager
71
- . next ( )
72
- . await
73
- . map ( |response| ( response, retry_manager) )
74
- }
63
+ let response_stream = stream:: unfold ( retry_manager, move |mut retry_manager| async move {
64
+ retry_manager
65
+ . next ( )
66
+ . await
67
+ . map ( |response| ( response, retry_manager) )
75
68
} ) ;
76
- Ok ( ResponseStream :: new ( Box :: pin ( response_stream) , engine_ctx ) )
69
+ Ok ( ResponseStream :: new ( Box :: pin ( response_stream) , engine_ctx_ ) )
77
70
}
78
71
}
79
72
80
73
struct RetryManager {
81
- context_id : String ,
74
+ context : Arc < dyn AsyncEngineContext > ,
82
75
request : PreprocessedRequest ,
83
76
next_generate : ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > ,
84
77
next_stream : Option < ManyOut < Annotated < LLMEngineOutput > > > ,
@@ -87,13 +80,13 @@ struct RetryManager {
87
80
88
81
impl RetryManager {
89
82
pub async fn build (
90
- context_id : String ,
83
+ context : Arc < dyn AsyncEngineContext > ,
91
84
preprocessed_request : PreprocessedRequest ,
92
85
next : ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > ,
93
86
retries_left : u32 ,
94
87
) -> Result < Self > {
95
88
let mut slf = Self {
96
- context_id ,
89
+ context ,
97
90
request : preprocessed_request,
98
91
next_generate : next,
99
92
next_stream : None ,
@@ -115,18 +108,16 @@ impl RetryManager {
115
108
}
116
109
} ;
117
110
if let Some ( response) = response_stream. next ( ) . await {
118
- if let Some ( err) = response. err ( ) {
119
- const STREAM_ERR_MSG : & str = "Stream ended before generation completed" ;
120
- if err
111
+ if let Some ( err) = response. err ( )
112
+ && err
121
113
. chain ( )
122
114
. any ( |e| e. to_string ( ) . starts_with ( STREAM_ERR_MSG ) )
123
- {
124
- tracing:: warn!( "Stream disconnected... recreating stream..." ) ;
125
- if let Err ( err) = self . new_stream ( ) . await {
126
- tracing:: warn!( "Cannot recreate stream: {:#}" , err) ;
127
- } else {
128
- continue ;
129
- }
115
+ {
116
+ tracing:: warn!( "Stream disconnected... recreating stream..." ) ;
117
+ if let Err ( err) = self . new_stream ( ) . await {
118
+ tracing:: warn!( "Cannot recreate stream: {:#}" , err) ;
119
+ } else {
120
+ continue ;
130
121
}
131
122
}
132
123
self . track_response ( & response) ;
@@ -140,7 +131,8 @@ impl RetryManager {
140
131
let mut response_stream: Option < Result < ManyOut < Annotated < LLMEngineOutput > > > > = None ;
141
132
while self . retries_left > 0 {
142
133
self . retries_left -= 1 ;
143
- let request = Context :: with_id ( self . request . clone ( ) , self . context_id . clone ( ) ) ;
134
+ let request = Context :: with_id ( self . request . clone ( ) , self . context . id ( ) . to_string ( ) ) ;
135
+ self . context . link_child ( request. context ( ) ) ;
144
136
response_stream = Some ( self . next_generate . generate ( request) . await ) ;
145
137
if let Some ( err) = response_stream. as_ref ( ) . unwrap ( ) . as_ref ( ) . err ( )
146
138
&& let Some ( req_err) = err. downcast_ref :: < NatsRequestError > ( )
@@ -339,10 +331,8 @@ mod tests {
339
331
}
340
332
}
341
333
// Send the specific error that triggers retry logic
342
- let error_response = Annotated :: from_err (
343
- anyhow:: Error :: msg ( "Stream ended before generation completed" )
344
- . into ( ) ,
345
- ) ;
334
+ let error_response =
335
+ Annotated :: from_err ( anyhow:: Error :: msg ( STREAM_ERR_MSG ) . into ( ) ) ;
346
336
let _ = tx. send ( error_response) . await ;
347
337
} ) ;
348
338
} else {
@@ -381,10 +371,8 @@ mod tests {
381
371
}
382
372
}
383
373
// Send the specific error that triggers retry logic
384
- let error_response = Annotated :: from_err (
385
- anyhow:: Error :: msg ( "Stream ended before generation completed" )
386
- . into ( ) ,
387
- ) ;
374
+ let error_response =
375
+ Annotated :: from_err ( anyhow:: Error :: msg ( STREAM_ERR_MSG ) . into ( ) ) ;
388
376
let _ = tx. send ( error_response) . await ;
389
377
} ) ;
390
378
@@ -417,10 +405,8 @@ mod tests {
417
405
}
418
406
}
419
407
// Send the specific error that triggers retry logic
420
- let error_response = Annotated :: from_err (
421
- anyhow:: Error :: msg ( "Stream ended before generation completed" )
422
- . into ( ) ,
423
- ) ;
408
+ let error_response =
409
+ Annotated :: from_err ( anyhow:: Error :: msg ( STREAM_ERR_MSG ) . into ( ) ) ;
424
410
let _ = tx. send ( error_response) . await ;
425
411
} ) ;
426
412
@@ -434,10 +420,8 @@ mod tests {
434
420
// Subsequent calls - immediately send stream error (no successful responses)
435
421
tokio:: spawn ( async move {
436
422
// Send the stream error immediately
437
- let error_response = Annotated :: from_err (
438
- anyhow:: Error :: msg ( "Stream ended before generation completed" )
439
- . into ( ) ,
440
- ) ;
423
+ let error_response =
424
+ Annotated :: from_err ( anyhow:: Error :: msg ( STREAM_ERR_MSG ) . into ( ) ) ;
441
425
let _ = tx. send ( error_response) . await ;
442
426
} ) ;
443
427
@@ -503,7 +487,8 @@ mod tests {
503
487
let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
504
488
mock_engine;
505
489
506
- let mut retry_manager = RetryManager :: build ( context_id, request, next_generate, 0 )
490
+ let ctx = Arc :: new ( Controller :: new ( context_id. clone ( ) ) ) ;
491
+ let mut retry_manager = RetryManager :: build ( ctx, request, next_generate, 0 )
507
492
. await
508
493
. expect ( "Failed to build RetryManager" ) ;
509
494
@@ -541,7 +526,8 @@ mod tests {
541
526
let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
542
527
mock_engine;
543
528
544
- let mut retry_manager = RetryManager :: build ( context_id, request, next_generate, 3 )
529
+ let ctx = Arc :: new ( Controller :: new ( context_id. clone ( ) ) ) ;
530
+ let mut retry_manager = RetryManager :: build ( ctx, request, next_generate, 3 )
545
531
. await
546
532
. expect ( "Failed to build RetryManager" ) ;
547
533
@@ -580,7 +566,8 @@ mod tests {
580
566
let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
581
567
mock_engine;
582
568
583
- let mut retry_manager = RetryManager :: build ( context_id, request, next_generate, 3 )
569
+ let ctx = Arc :: new ( Controller :: new ( context_id. clone ( ) ) ) ;
570
+ let mut retry_manager = RetryManager :: build ( ctx, request, next_generate, 3 )
584
571
. await
585
572
. expect ( "Failed to build RetryManager" ) ;
586
573
@@ -620,7 +607,8 @@ mod tests {
620
607
mock_engine;
621
608
622
609
// Should fail to build due to initial stream creation failure after exhausting all 3 retries
623
- let retry_manager_result = RetryManager :: build ( context_id, request, next_generate, 3 ) . await ;
610
+ let ctx = Arc :: new ( Controller :: new ( context_id. clone ( ) ) ) ;
611
+ let retry_manager_result = RetryManager :: build ( ctx, request, next_generate, 3 ) . await ;
624
612
625
613
assert ! ( retry_manager_result. is_err( ) ) ;
626
614
if let Err ( error) = retry_manager_result {
@@ -646,7 +634,8 @@ mod tests {
646
634
let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
647
635
mock_engine;
648
636
649
- let mut retry_manager = RetryManager :: build ( context_id, request, next_generate, 3 ) // 3 retries
637
+ let ctx = Arc :: new ( Controller :: new ( context_id. clone ( ) ) ) ;
638
+ let mut retry_manager = RetryManager :: build ( ctx, request, next_generate, 3 ) // 3 retries
650
639
. await
651
640
. expect ( "Failed to build RetryManager" ) ;
652
641
@@ -672,11 +661,7 @@ mod tests {
672
661
let error_response = & responses[ 3 ] ;
673
662
assert ! ( error_response. err( ) . is_some( ) ) ;
674
663
if let Some ( error) = error_response. err ( ) {
675
- assert ! (
676
- error
677
- . to_string( )
678
- . contains( "Stream ended before generation completed" )
679
- ) ;
664
+ assert ! ( error. to_string( ) . contains( STREAM_ERR_MSG ) ) ;
680
665
}
681
666
}
682
667
@@ -698,7 +683,8 @@ mod tests {
698
683
let next_generate: ServerStreamingEngine < PreprocessedRequest , Annotated < LLMEngineOutput > > =
699
684
mock_engine;
700
685
701
- let mut retry_manager = RetryManager :: build ( context_id, request, next_generate, 3 ) // 3 retries
686
+ let ctx = Arc :: new ( Controller :: new ( context_id. clone ( ) ) ) ;
687
+ let mut retry_manager = RetryManager :: build ( ctx, request, next_generate, 3 ) // 3 retries
702
688
. await
703
689
. expect ( "Failed to build RetryManager" ) ;
704
690
@@ -724,11 +710,7 @@ mod tests {
724
710
let error_response = & responses[ 3 ] ;
725
711
assert ! ( error_response. err( ) . is_some( ) ) ;
726
712
if let Some ( error) = error_response. err ( ) {
727
- assert ! (
728
- error
729
- . to_string( )
730
- . contains( "Stream ended before generation completed" )
731
- ) ;
713
+ assert ! ( error. to_string( ) . contains( STREAM_ERR_MSG ) ) ;
732
714
}
733
715
}
734
716
}
0 commit comments