
Commit 1767c8f

chuangz0 authored and Ria Jain committed
use cudaSetDevice to create context, fix nvbug 5394497 (NVIDIA#6403)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent: c54dc55 · commit: 1767c8f

File tree

3 files changed: +10 -2 lines changed


cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 1 addition & 1 deletion
@@ -347,7 +347,7 @@ void CacheFormatter::format(TransferSession& session)
             auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize);
             bufferManager.copy(*copySlice, *copyTargetSlice);
             bufferManager.getStream().synchronize();
-            session.send(processIdx, copyTargetSlice->data(), sendSize);
+            session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes());
             remainSendSize -= sendSize;
         }
     }

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ void MLACacheFormatter::format(TransferSession& session)
             auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize);
             bufferManager.copy(*copySlice, *copyTargetSlice);
             bufferManager.getStream().synchronize();
-            session.send(processIdx, copyTargetSlice->data(), sendSize);
+            session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes());

             remainSendSize -= sendSize;
         }
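Both C++ hunks make the same one-line change: the size passed to session.send now comes from the slice itself (copyTargetSlice->getSizeInBytes()) rather than from the sendSize variable, so the argument always matches the buffer actually handed to the transport. A minimal sketch of why the two values can disagree, assuming (as the switch to getSizeInBytes() suggests) a count of multi-byte elements versus a byte count; this is an illustration only, with numpy standing in for runtime::ITensor:

    import numpy as np

    # Hypothetical kv-cache slice of 16-bit elements.
    send_size = 1024                                  # element count, like sendSize
    copy_target_slice = np.empty(send_size, dtype=np.float16)

    print(copy_target_slice.size)    # 1024 -> the value the old call passed
    print(copy_target_slice.nbytes)  # 2048 -> analogous to getSizeInBytes()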

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 8 additions & 0 deletions
@@ -11,6 +11,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+from cuda import cudart
 
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
@@ -25,6 +26,7 @@
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                           ReqIdsSet)
 from tensorrt_llm.logger import logger
+from tensorrt_llm.runtime.generation import CUASSERT
 
 from ..distributed import Distributed
 from ..speculative.drafter import Drafter
@@ -644,6 +646,8 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
     def _executor_loop_pp(self):
         logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
         torch.cuda.set_device(self.device_id)
+        # ensure the context is created, otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         microbatch_id = 0
         with self._profiler() as profile_step:
             iter_start_time = time.time()
@@ -897,6 +901,8 @@ def _execute_guided_decoder(self, scheduled_batch, logits):
 
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
+        # ensure the context is created, otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         with self._profiler() as profile_step:
             sample_state = None
             iter_start_time = time.time()
@@ -1014,6 +1020,8 @@ def _prepare_draft_requests(self):
 
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
+        # ensure the context is created, otherwise, some MPI calls will fail.
+        CUASSERT(cudart.cudaSetDevice(self.device_id))
         if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
             while self.executor_request_queue.get_request_queue_size(
             ) < self.benchmark_req_queues_size:
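The Python change is the fix named in the commit title. torch.cuda.set_device alone can leave CUDA initialization deferred, so each executor loop also calls cudaSetDevice through the cuda-python bindings; per the in-diff comment, this ensures the CUDA context is created before MPI calls that depend on it. A self-contained sketch of the same pattern, with ensure_cuda_context as an illustrative helper name (the patch itself wraps the call in CUASSERT from tensorrt_llm.runtime.generation):

    from cuda import cudart

    def ensure_cuda_context(device_id: int) -> None:
        # cudaSetDevice selects the device and initializes its primary
        # context, so later context-dependent calls (e.g. CUDA-aware MPI)
        # do not fail for lack of a context.
        err, = cudart.cudaSetDevice(device_id)
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"cudaSetDevice({device_id}) failed: {err}")

    ensure_cuda_context(0)  # e.g. before any device-side communication on GPU 0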
