|
@@ -11,6 +11,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+from cuda import cudart
 
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
@@ -25,6 +26,7 @@
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                           ReqIdsSet)
 from tensorrt_llm.logger import logger
+from tensorrt_llm.runtime.generation import CUASSERT
 
 from ..distributed import Distributed
 from ..speculative.drafter import Drafter
@@ -644,6 +646,8 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
     def _executor_loop_pp(self):
         logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
         torch.cuda.set_device(self.device_id)
+        # ensure the CUDA context is created; otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         microbatch_id = 0
         with self._profiler() as profile_step:
             iter_start_time = time.time()
@@ -897,6 +901,8 @@ def _execute_guided_decoder(self, scheduled_batch, logits):
 
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
+        # ensure the CUDA context is created; otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         with self._profiler() as profile_step:
             sample_state = None
             iter_start_time = time.time()
@@ -1014,6 +1020,8 @@ def _prepare_draft_requests(self):
 
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
+        # ensure the CUDA context is created; otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
             while self.executor_request_queue.get_request_queue_size(
             ) < self.benchmark_req_queues_size:
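
Why the extra call: `torch.cuda.set_device` only records the target device on the PyTorch side and may defer creation of the CUDA primary context until the first allocation or kernel launch, whereas CUDA-aware MPI calls need an already-initialized context on the device. As of CUDA 12, `cudaSetDevice` through the runtime API eagerly initializes the primary context, which is what the added `CUASSERT(cudart.cudaSetDevice(...))` line relies on. Below is a minimal, hypothetical sketch of the same pattern; `cuassert` and `bind_device` are illustrative names, not the actual `CUASSERT` helper from `tensorrt_llm.runtime.generation`, and the sketch assumes cuda-python's convention that every `cudart` call returns a tuple whose first element is a `cudaError_t`.

```python
# A minimal sketch of the device-binding pattern used in this PR.
# Assumption: cuda-python returns (err, *results) from every cudart call.
import torch
from cuda import cudart


def cuassert(cuda_ret):
    """Raise on a failed cudart call; return any trailing results."""
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"CUDA runtime API error: {err}")
    return cuda_ret[1:] if len(cuda_ret) > 1 else None


def bind_device(device_id: int) -> None:
    # Point PyTorch at the device; the primary context may still be
    # created lazily, so this alone is not enough for CUDA-aware MPI.
    torch.cuda.set_device(device_id)
    # cudaSetDevice via the runtime API (CUDA 12+) eagerly initializes
    # the primary context, so later MPI calls find a valid context.
    cuassert(cudart.cudaSetDevice(device_id))
```

Placing this at the top of each executor loop means every rank has a live context before its first communication, which is why the same two lines are repeated in `_executor_loop_pp`, `_executor_loop`, and `_executor_loop_overlap`.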
|