|
11 | 11 | from typing import Dict, List, Optional, Union
|
12 | 12 |
|
13 | 13 | import torch
|
| 14 | +from cuda import cudart |
14 | 15 |
|
15 | 16 | from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
|
16 | 17 | from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
|
|
25 | 26 | from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
|
26 | 27 | ReqIdsSet)
|
27 | 28 | from tensorrt_llm.logger import logger
|
| 29 | +from tensorrt_llm.runtime.generation import CUASSERT |
28 | 30 |
|
29 | 31 | from ..distributed import Distributed
|
30 | 32 | from ..speculative.drafter import Drafter
|
@@ -634,6 +636,8 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
|
634 | 636 | def _executor_loop_pp(self):
|
635 | 637 | logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
|
636 | 638 | torch.cuda.set_device(self.device_id)
|
| 639 | +        # Ensure the CUDA context is created; otherwise some MPI calls will fail. |
| 640 | + CUASSERT(cudart.cudaSetDevice(self.device_id)) |
637 | 641 | microbatch_id = 0
|
638 | 642 | with self._profiler() as profile_step:
|
639 | 643 | iter_start_time = time.time()
|
@@ -887,6 +891,8 @@ def _execute_guided_decoder(self, scheduled_batch, logits):
|
887 | 891 |
|
888 | 892 | def _executor_loop(self):
|
889 | 893 | torch.cuda.set_device(self.device_id)
|
| 894 | +        # Ensure the CUDA context is created; otherwise some MPI calls will fail. |
| 895 | + CUASSERT(cudart.cudaSetDevice(self.device_id)) |
890 | 896 | with self._profiler() as profile_step:
|
891 | 897 | sample_state = None
|
892 | 898 | iter_start_time = time.time()
|
@@ -982,6 +988,8 @@ def _prepare_draft_requests(self):
|
982 | 988 |
|
983 | 989 | def _executor_loop_overlap(self):
|
984 | 990 | torch.cuda.set_device(self.device_id)
|
| 991 | +        # Ensure the CUDA context is created; otherwise some MPI calls will fail. |
| 992 | + CUASSERT(cudart.cudaSetDevice(self.device_id)) |
985 | 993 | if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
|
986 | 994 | while self.executor_request_queue.get_request_queue_size(
|
987 | 995 | ) < self.benchmark_req_queues_size:
|
|
0 commit comments