Skip to content

Commit a8fd127

Browse files
authored
feat: Request Cancellation unary request support (#3004)
Signed-off-by: Jacky <[email protected]>
1 parent 10bfb73 commit a8fd127

File tree

11 files changed

+190
-131
lines changed

11 files changed

+190
-131
lines changed

examples/custom_backend/cancellation/middle_server.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,12 @@ async def generate(self, request, context):
3939
stream = await self.backend_client.generate(request, context=context)
4040

4141
# Stream responses back to client
42-
try:
43-
async for response in stream:
44-
data = response.data()
45-
print(f"Middle server: Forwarding response {data}")
46-
yield data
47-
48-
except ValueError as e:
49-
if str(e) != "Stream ended before generation completed":
50-
raise
51-
print("Middle server: Backend stream ended early due to cancellation")
42+
async for response in stream:
43+
data = response.data()
44+
print(f"Middle server: Forwarding response {data}")
45+
yield data
46+
47+
print("Middle server: Backend stream ended")
5248

5349

5450
async def main():

lib/bindings/python/tests/test_cancellation/test_client_context_cancel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,17 @@ async def test_client_context_cancel(server, client):
3838
if iteration_count >= 2:
3939
print("Cancelling after 2 responses...")
4040
context.stop_generating()
41-
break
4241

4342
iteration_count += 1
4443

44+
# Verify we received exactly 3 responses (0, 1, 2)
45+
assert iteration_count == 3
46+
4547
# Give server a moment to process the cancellation
4648
await asyncio.sleep(0.2)
4749

4850
# Verify server detected the cancellation
4951
assert handler.context_is_stopped
50-
assert handler.context_is_killed
52+
assert not handler.context_is_killed
5153

5254
# TODO: Test with _generate_until_asyncio_cancelled server handler

lib/bindings/python/tests/test_cancellation/test_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ async def test_middle_server_cancellation(
139139
assert (
140140
"Client: Cancelling after 3 responses..." in client_output
141141
), f"Client output: {client_output}"
142+
assert (
143+
"Middle server: Forwarding response 2" in middle_output
144+
), f"Middle server output: {middle_output}"
142145
assert (
143146
"Server: Cancelled at iteration" in server_output
144147
), f"Server output: {server_output}"
145-
assert (
146-
"Middle server: Backend stream ended early due to cancellation" in middle_output
147-
), f"Middle server output: {middle_output}"

lib/llm/src/migration.rs

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ use crate::{
1717

1818
use dynamo_runtime::{
1919
pipeline::{
20-
AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
21-
ServerStreamingEngine, SingleIn, async_trait,
20+
AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
21+
ServerStreamingEngine, SingleIn, async_trait, network::STREAM_ERR_MSG,
2222
},
2323
protocols::{annotated::Annotated, maybe_error::MaybeError},
2424
};
@@ -55,30 +55,23 @@ impl
5555
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
5656
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
5757
let (preprocessed_request, context) = request.transfer(());
58-
let context_id = context.id().to_string();
5958
let engine_ctx = context.context();
6059
let engine_ctx_ = engine_ctx.clone();
6160
let retry_manager =
62-
RetryManager::build(context_id, preprocessed_request, next, self.migration_limit)
61+
RetryManager::build(engine_ctx, preprocessed_request, next, self.migration_limit)
6362
.await?;
64-
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| {
65-
let engine_ctx = engine_ctx_.clone();
66-
async move {
67-
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
68-
return None; // Stop if the context is cancelled or stopped
69-
}
70-
retry_manager
71-
.next()
72-
.await
73-
.map(|response| (response, retry_manager))
74-
}
63+
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
64+
retry_manager
65+
.next()
66+
.await
67+
.map(|response| (response, retry_manager))
7568
});
76-
Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx))
69+
Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
7770
}
7871
}
7972

8073
struct RetryManager {
81-
context_id: String,
74+
context: Arc<dyn AsyncEngineContext>,
8275
request: PreprocessedRequest,
8376
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
8477
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
@@ -87,13 +80,13 @@ struct RetryManager {
8780

8881
impl RetryManager {
8982
pub async fn build(
90-
context_id: String,
83+
context: Arc<dyn AsyncEngineContext>,
9184
preprocessed_request: PreprocessedRequest,
9285
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
9386
retries_left: u32,
9487
) -> Result<Self> {
9588
let mut slf = Self {
96-
context_id,
89+
context,
9790
request: preprocessed_request,
9891
next_generate: next,
9992
next_stream: None,
@@ -115,18 +108,16 @@ impl RetryManager {
115108
}
116109
};
117110
if let Some(response) = response_stream.next().await {
118-
if let Some(err) = response.err() {
119-
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
120-
if err
111+
if let Some(err) = response.err()
112+
&& err
121113
.chain()
122114
.any(|e| e.to_string().starts_with(STREAM_ERR_MSG))
123-
{
124-
tracing::warn!("Stream disconnected... recreating stream...");
125-
if let Err(err) = self.new_stream().await {
126-
tracing::warn!("Cannot recreate stream: {:#}", err);
127-
} else {
128-
continue;
129-
}
115+
{
116+
tracing::warn!("Stream disconnected... recreating stream...");
117+
if let Err(err) = self.new_stream().await {
118+
tracing::warn!("Cannot recreate stream: {:#}", err);
119+
} else {
120+
continue;
130121
}
131122
}
132123
self.track_response(&response);
@@ -140,7 +131,8 @@ impl RetryManager {
140131
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
141132
while self.retries_left > 0 {
142133
self.retries_left -= 1;
143-
let request = Context::with_id(self.request.clone(), self.context_id.clone());
134+
let request = Context::with_id(self.request.clone(), self.context.id().to_string());
135+
self.context.link_child(request.context());
144136
response_stream = Some(self.next_generate.generate(request).await);
145137
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
146138
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
@@ -339,10 +331,8 @@ mod tests {
339331
}
340332
}
341333
// Send the specific error that triggers retry logic
342-
let error_response = Annotated::from_err(
343-
anyhow::Error::msg("Stream ended before generation completed")
344-
.into(),
345-
);
334+
let error_response =
335+
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
346336
let _ = tx.send(error_response).await;
347337
});
348338
} else {
@@ -381,10 +371,8 @@ mod tests {
381371
}
382372
}
383373
// Send the specific error that triggers retry logic
384-
let error_response = Annotated::from_err(
385-
anyhow::Error::msg("Stream ended before generation completed")
386-
.into(),
387-
);
374+
let error_response =
375+
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
388376
let _ = tx.send(error_response).await;
389377
});
390378

@@ -417,10 +405,8 @@ mod tests {
417405
}
418406
}
419407
// Send the specific error that triggers retry logic
420-
let error_response = Annotated::from_err(
421-
anyhow::Error::msg("Stream ended before generation completed")
422-
.into(),
423-
);
408+
let error_response =
409+
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
424410
let _ = tx.send(error_response).await;
425411
});
426412

@@ -434,10 +420,8 @@ mod tests {
434420
// Subsequent calls - immediately send stream error (no successful responses)
435421
tokio::spawn(async move {
436422
// Send the stream error immediately
437-
let error_response = Annotated::from_err(
438-
anyhow::Error::msg("Stream ended before generation completed")
439-
.into(),
440-
);
423+
let error_response =
424+
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
441425
let _ = tx.send(error_response).await;
442426
});
443427

@@ -503,7 +487,8 @@ mod tests {
503487
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
504488
mock_engine;
505489

506-
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 0)
490+
let ctx = Arc::new(Controller::new(context_id.clone()));
491+
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 0)
507492
.await
508493
.expect("Failed to build RetryManager");
509494

@@ -541,7 +526,8 @@ mod tests {
541526
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
542527
mock_engine;
543528

544-
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
529+
let ctx = Arc::new(Controller::new(context_id.clone()));
530+
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
545531
.await
546532
.expect("Failed to build RetryManager");
547533

@@ -580,7 +566,8 @@ mod tests {
580566
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
581567
mock_engine;
582568

583-
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
569+
let ctx = Arc::new(Controller::new(context_id.clone()));
570+
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
584571
.await
585572
.expect("Failed to build RetryManager");
586573

@@ -620,7 +607,8 @@ mod tests {
620607
mock_engine;
621608

622609
// Should fail to build due to initial stream creation failure after exhausting all 3 retries
623-
let retry_manager_result = RetryManager::build(context_id, request, next_generate, 3).await;
610+
let ctx = Arc::new(Controller::new(context_id.clone()));
611+
let retry_manager_result = RetryManager::build(ctx, request, next_generate, 3).await;
624612

625613
assert!(retry_manager_result.is_err());
626614
if let Err(error) = retry_manager_result {
@@ -646,7 +634,8 @@ mod tests {
646634
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
647635
mock_engine;
648636

649-
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
637+
let ctx = Arc::new(Controller::new(context_id.clone()));
638+
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
650639
.await
651640
.expect("Failed to build RetryManager");
652641

@@ -672,11 +661,7 @@ mod tests {
672661
let error_response = &responses[3];
673662
assert!(error_response.err().is_some());
674663
if let Some(error) = error_response.err() {
675-
assert!(
676-
error
677-
.to_string()
678-
.contains("Stream ended before generation completed")
679-
);
664+
assert!(error.to_string().contains(STREAM_ERR_MSG));
680665
}
681666
}
682667

@@ -698,7 +683,8 @@ mod tests {
698683
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
699684
mock_engine;
700685

701-
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
686+
let ctx = Arc::new(Controller::new(context_id.clone()));
687+
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
702688
.await
703689
.expect("Failed to build RetryManager");
704690

@@ -724,11 +710,7 @@ mod tests {
724710
let error_response = &responses[3];
725711
assert!(error_response.err().is_some());
726712
if let Some(error) = error_response.err() {
727-
assert!(
728-
error
729-
.to_string()
730-
.contains("Stream ended before generation completed")
731-
);
713+
assert!(error.to_string().contains(STREAM_ERR_MSG));
732714
}
733715
}
734716
}

lib/runtime/src/pipeline/context.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,18 +358,20 @@ impl AsyncEngineContext for Controller {
358358

359359
async fn stopped(&self) {
360360
let mut rx = self.rx.clone();
361-
if *rx.borrow_and_update() != State::Live {
362-
return;
361+
loop {
362+
if *rx.borrow_and_update() != State::Live || rx.changed().await.is_err() {
363+
return;
364+
}
363365
}
364-
let _ = rx.changed().await;
365366
}
366367

367368
async fn killed(&self) {
368369
let mut rx = self.rx.clone();
369-
if *rx.borrow_and_update() == State::Killed {
370-
return;
370+
loop {
371+
if *rx.borrow_and_update() == State::Killed || rx.changed().await.is_err() {
372+
return;
373+
}
371374
}
372-
let _ = rx.changed().await;
373375
}
374376

375377
fn stop_generating(&self) {

lib/runtime/src/pipeline/network.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ use super::{
2727
};
2828
use ingress::push_handler::WorkHandlerMetrics;
2929

30+
// Define stream error message constant
31+
pub const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
32+
3033
// Add Prometheus metrics types
3134
use crate::metrics::MetricsRegistry;
3235
use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};

lib/runtime/src/pipeline/network/egress/addressed_router.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ where
8080
let (addressed_request, context) = request.transfer(());
8181
let (request, address) = addressed_request.into_parts();
8282
let engine_ctx = context.context();
83+
let engine_ctx_ = engine_ctx.clone();
8384

8485
// registration options for the data plane in a single in / many out configuration
8586
let options = StreamOptions::builder()
@@ -209,11 +210,18 @@ where
209210
}
210211
}
211212
} else if is_complete_final {
213+
// end of stream
214+
None
215+
} else if engine_ctx_.is_stopped() {
216+
// Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
217+
// 'is_killed()' here because it implies the stream ended abnormally which should be
218+
// handled by the error branch below.
219+
log::debug!("Request cancelled and then trying to read a response");
212220
None
213221
} else {
214-
Some(U::from_err(
215-
Error::msg("Stream ended before generation completed").into(),
216-
))
222+
// stream ended unexpectedly
223+
log::debug!("{STREAM_ERR_MSG}");
224+
Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
217225
}
218226
});
219227

0 commit comments

Comments
 (0)