Skip to content

Commit e4cafa1

Browse files
Superjomn and dominicshanshan
authored and committed
[https://nvbugs/5383702][fix] error propagation in GenerationExecutor (#6793)
Signed-off-by: Superjomn <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent 3497837 commit e4cafa1

File tree

3 files changed

+42
-3
lines changed

3 files changed

+42
-3
lines changed

tensorrt_llm/executor/proxy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,14 +317,15 @@ def mpi_done_callback(future: concurrent.futures.Future):
317317

318318
while True:
319319
if self.worker_init_status_queue.poll(1):
320-
ready_signal = self.worker_init_status_queue.get()
320+
ready_signal, error_trace = self.worker_init_status_queue.get()
321321
break
322322
if any(fut.done() for fut in self.mpi_futures):
323323
logger.error("Executor worker died during initialization.")
324324
raise RuntimeError("Executor worker died during initialization")
325325
self._handle_background_error()
326326

327327
if ready_signal != GenerationExecutorProxy.READY_SIGNAL:
328+
logger.error(f"Executor worker initialization error: {error_trace}")
328329
self.mpi_session.shutdown_abort(reason=ready_signal)
329330
raise RuntimeError(
330331
"Executor worker returned error") from ready_signal

tensorrt_llm/executor/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ def notify_proxy_threads_to_quit():
774774
logger.error(traceback.format_exc())
775775
print_colored_debug(f"error: {traceback.format_exc()}", "red")
776776
if is_leader:
777-
worker_init_status_queue.put(e)
777+
worker_init_status_queue.put((e, traceback.format_exc()))
778778
return
779779

780780
with worker:
@@ -792,7 +792,7 @@ def notify_proxy_threads_to_quit():
792792
mp_stats_queue)
793793
worker._set_iteration_result_queue(worker.kv_events_queues,
794794
kv_cache_events_queue)
795-
worker_init_status_queue.put(ready_signal)
795+
worker_init_status_queue.put((ready_signal, None))
796796
while (req := request_queue.get()) is not None:
797797
if isinstance(req, CancellingRequest):
798798
worker.abort_request(req.id)

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from tensorrt_llm import LLM
7+
from tensorrt_llm.executor import GenerationExecutorWorker
78
from tensorrt_llm.llmapi import KvCacheConfig
89
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
910
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
@@ -818,3 +819,40 @@ def test_max_num_token_check(self):
818819
match="should not exceed max_num_tokens"):
819820
ids = [random.randint(10, 100) for _ in range(101)]
820821
llm.generate([ids])
822+
823+
824+
class FailingExecutorWorker(GenerationExecutorWorker):
825+
"""Mock worker that fails during initialization to test error handling."""
826+
827+
def __init__(self, *args, **kwargs):
828+
# Simulate a constructor failure
829+
raise RuntimeError(
830+
"Mock GenerationExecutorWorker initialization failed")
831+
832+
833+
FailingExecutor = type(
834+
"FailingExecutor", (), {
835+
"create":
836+
classmethod(
837+
lambda cls, *args, **kwargs: FailingExecutorWorker(*args, **kwargs))
838+
})
839+
840+
841+
def test_llm_with_proxy_error():
842+
"""Test that LLM properly handles GenerationExecutorWorker constructor failures.
843+
844+
This test mocks the GenerationExecutorWorker to fail during __init__ and
845+
verifies that the LLM class properly catches and re-raises the error.
846+
"""
847+
from unittest.mock import patch
848+
849+
# Test that the error is properly caught and re-raised by LLM
850+
# We patch GenerationExecutor.create directly to return our failing worker
851+
with patch('tensorrt_llm.executor.executor.GenerationExecutor.create',
852+
side_effect=lambda *args, **kwargs: FailingExecutorWorker(
853+
*args, **kwargs)):
854+
with pytest.raises(
855+
RuntimeError,
856+
match="Mock GenerationExecutorWorker initialization failed"):
857+
llm = LLM(model=llama_model_path,
858+
kv_cache_config=global_kvcache_config)

0 commit comments

Comments (0)