Skip to content

Commit 389b472

Browse files
chuangz0 and lancelly
authored and committed
use cudaSetDevice to create context, fix nvbug 5394497 (NVIDIA#6403)
Signed-off-by: Chuang Zhu <[email protected]> Signed-off-by: Lanyu Liao <[email protected]>
1 parent c294c41 commit 389b472

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ void CacheFormatter::format(TransferSession& session)
347347
auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize);
348348
bufferManager.copy(*copySlice, *copyTargetSlice);
349349
bufferManager.getStream().synchronize();
350-
session.send(processIdx, copyTargetSlice->data(), sendSize);
350+
session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes());
351351
remainSendSize -= sendSize;
352352
}
353353
}

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ void MLACacheFormatter::format(TransferSession& session)
223223
auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize);
224224
bufferManager.copy(*copySlice, *copyTargetSlice);
225225
bufferManager.getStream().synchronize();
226-
session.send(processIdx, copyTargetSlice->data(), sendSize);
226+
session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes());
227227

228228
remainSendSize -= sendSize;
229229
}

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Dict, List, Optional, Union
1212

1313
import torch
14+
from cuda import cudart
1415

1516
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
1617
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
@@ -25,6 +26,7 @@
2526
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
2627
ReqIdsSet)
2728
from tensorrt_llm.logger import logger
29+
from tensorrt_llm.runtime.generation import CUASSERT
2830

2931
from ..distributed import Distributed
3032
from ..speculative.drafter import Drafter
@@ -634,6 +636,8 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
634636
def _executor_loop_pp(self):
635637
logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
636638
torch.cuda.set_device(self.device_id)
639+
# ensure the context is created, otherwise, some MPI calls will fail.
640+
CUASSERT(cudart.cudaSetDevice(self.device_id))
637641
microbatch_id = 0
638642
with self._profiler() as profile_step:
639643
iter_start_time = time.time()
@@ -887,6 +891,8 @@ def _execute_guided_decoder(self, scheduled_batch, logits):
887891

888892
def _executor_loop(self):
889893
torch.cuda.set_device(self.device_id)
894+
# ensure the context is created, otherwise, some MPI calls will fail.
895+
CUASSERT(cudart.cudaSetDevice(self.device_id))
890896
with self._profiler() as profile_step:
891897
sample_state = None
892898
iter_start_time = time.time()
@@ -982,6 +988,8 @@ def _prepare_draft_requests(self):
982988

983989
def _executor_loop_overlap(self):
984990
torch.cuda.set_device(self.device_id)
991+
# ensure the context is created, otherwise, some MPI calls will fail.
992+
CUASSERT(cudart.cudaSetDevice(self.device_id))
985993
if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
986994
while self.executor_request_queue.get_request_queue_size(
987995
) < self.benchmark_req_queues_size:

0 commit comments

Comments (0)