@@ -11,6 +11,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+from cuda import cudart
 
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
@@ -25,6 +26,7 @@
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                           ReqIdsSet)
 from tensorrt_llm.logger import logger
+from tensorrt_llm.runtime.generation import CUASSERT
 
 from ..distributed import Distributed
 from ..speculative.drafter import Drafter
@@ -644,6 +646,8 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
     def _executor_loop_pp(self):
         logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
         torch.cuda.set_device(self.device_id)
+        # Ensure the CUDA context is created; otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         microbatch_id = 0
         with self._profiler() as profile_step:
             iter_start_time = time.time()
@@ -897,6 +901,8 @@ def _execute_guided_decoder(self, scheduled_batch, logits):
 
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
+        # Ensure the CUDA context is created; otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         with self._profiler() as profile_step:
             sample_state = None
             iter_start_time = time.time()
@@ -992,6 +998,8 @@ def _prepare_draft_requests(self):
 
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
+        # Ensure the CUDA context is created; otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
             while self.executor_request_queue.get_request_queue_size(
             ) < self.benchmark_req_queues_size:
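
For readers unfamiliar with the pattern above: `torch.cuda.set_device` only selects the device on the PyTorch side and, on recent PyTorch versions, does not necessarily create a CUDA context; the context is created lazily by the first CUDA operation. The extra `cudart.cudaSetDevice` call through cuda-python eagerly initializes the device's primary context so that subsequent CUDA-aware MPI calls find one. Below is a minimal sketch of the convention, assuming `CUASSERT` (imported here from `tensorrt_llm.runtime.generation`) checks the error code in the `(error, *values)` tuples that cuda-python returns; the helper below is an illustrative stand-in, not the library's exact implementation.

```python
from cuda import cudart  # cuda-python runtime bindings


def CUASSERT(cuda_ret):
    # Illustrative stand-in for tensorrt_llm.runtime.generation.CUASSERT:
    # cuda-python calls return (error, *values) tuples, so check the error
    # code and unwrap any remaining values.
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"CUDA runtime error: {err}")
    return cuda_ret[1:] if len(cuda_ret) > 1 else None


device_id = 0  # stand-in for self.device_id in the diff above

# cudaSetDevice via the runtime API initializes the device's primary
# context if it does not exist yet, which torch.cuda.set_device alone
# does not guarantee.
CUASSERT(cudart.cudaSetDevice(device_id))

# Any runtime call now operates against a live context, e.g.:
free_mem, total_mem = CUASSERT(cudart.cudaMemGetInfo())
print(f"device {device_id}: {free_mem} / {total_mem} bytes free")
```

The runtime-API call is effectively idempotent when the primary context already exists, so placing it right after `torch.cuda.set_device` is safe on ranks where PyTorch has already initialized the device.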