@@ -4,6 +4,7 @@
 import pytest

 from tensorrt_llm import LLM
+from tensorrt_llm.executor import GenerationExecutorWorker
 from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
@@ -818,3 +819,43 @@ def test_max_num_token_check(self):
                            match="should not exceed max_num_tokens"):
             ids = [random.randint(10, 100) for _ in range(101)]
             llm.generate([ids])
+
+
+class FailingExecutorWorker(GenerationExecutorWorker):
+    """Mock worker that fails during initialization to test error handling."""
+
+    def __init__(self, *args, **kwargs):
+        # Simulate a constructor failure
+        raise RuntimeError(
+            "Mock GenerationExecutorWorker initialization failed")
+
+
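+# FailingExecutor mirrors GenerationExecutor's ``create`` factory interface:
+# invoking ``FailingExecutor.create(...)`` constructs the failing worker and
+# therefore raises the mock RuntimeError.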
+FailingExecutor = type(
+    "FailingExecutor", (), {
+        "create":
+        classmethod(
+            lambda cls, *args, **kwargs: FailingExecutorWorker(*args, **kwargs))
+    })
+
+
+def test_llm_with_proxy_error():
+    """Test that LLM surfaces GenerationExecutorWorker constructor failures.
+
+    This test mocks GenerationExecutorWorker to fail during __init__ and
+    verifies that the LLM constructor propagates the resulting RuntimeError.
+    """
+    from unittest.mock import patch
+
+    # Patch GenerationExecutor.create so that building the executor constructs
+    # our failing worker, which raises during __init__.
+    with patch('tensorrt_llm.executor.executor.GenerationExecutor.create',
+               side_effect=lambda *args, **kwargs: FailingExecutorWorker(
+                   *args, **kwargs)):
+        with pytest.raises(
+                RuntimeError,
+                match="Mock GenerationExecutorWorker initialization failed"):
+            LLM(model=llama_model_path,
+                kv_cache_config=global_kvcache_config)
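
For reference, the patching pattern this test relies on can be reproduced with only the standard library. A minimal sketch, where `Factory` and `Worker` are illustrative stand-ins rather than TensorRT-LLM APIs:

```python
from unittest.mock import patch


class Worker:
    """Stand-in for GenerationExecutorWorker: always fails to construct."""

    def __init__(self, *args, **kwargs):
        raise RuntimeError("worker initialization failed")


class Factory:
    """Stand-in for GenerationExecutor, exposing a create() factory."""

    @classmethod
    def create(cls, *args, **kwargs):
        return Worker(*args, **kwargs)


# Replace Factory.create with a mock whose side_effect builds the failing
# worker, so any caller of the factory observes the constructor's error.
with patch.object(Factory, "create", side_effect=Worker):
    try:
        Factory.create()
    except RuntimeError as exc:
        assert "worker initialization failed" in str(exc)
```

Because `side_effect` receives the arguments of each call, the mock forwards whatever the caller passes, just as the lambda in the test above forwards `*args, **kwargs` to `FailingExecutorWorker`.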