Skip to content

Commit 542f552

Browse files
authored
use cudaSetDevice to create context, fix nvbug 5394497 (#6403)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 3f7abf8 commit 542f552

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ void CacheFormatter::format(TransferSession& session)
347347
auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize);
348348
bufferManager.copy(*copySlice, *copyTargetSlice);
349349
bufferManager.getStream().synchronize();
350-
session.send(processIdx, copyTargetSlice->data(), sendSize);
350+
session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes());
351351
remainSendSize -= sendSize;
352352
}
353353
}

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ void MLACacheFormatter::format(TransferSession& session)
223223
auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize);
224224
bufferManager.copy(*copySlice, *copyTargetSlice);
225225
bufferManager.getStream().synchronize();
226-
session.send(processIdx, copyTargetSlice->data(), sendSize);
226+
session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes());
227227

228228
remainSendSize -= sendSize;
229229
}

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Dict, List, Optional, Union
1212

1313
import torch
14+
from cuda import cudart
1415

1516
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
1617
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
@@ -25,6 +26,7 @@
2526
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
2627
ReqIdsSet)
2728
from tensorrt_llm.logger import logger
29+
from tensorrt_llm.runtime.generation import CUASSERT
2830

2931
from ..distributed import Distributed
3032
from ..speculative.drafter import Drafter
@@ -644,6 +646,8 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
644646
def _executor_loop_pp(self):
645647
logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
646648
torch.cuda.set_device(self.device_id)
649+
# ensure the context is created, otherwise, some MPI calls will fail.
650+
CUASSERT(cudart.cudaSetDevice(self.device_id))
647651
microbatch_id = 0
648652
with self._profiler() as profile_step:
649653
iter_start_time = time.time()
@@ -897,6 +901,8 @@ def _execute_guided_decoder(self, scheduled_batch, logits):
897901

898902
def _executor_loop(self):
899903
torch.cuda.set_device(self.device_id)
904+
# ensure the context is created, otherwise, some MPI calls will fail.
905+
CUASSERT(cudart.cudaSetDevice(self.device_id))
900906
with self._profiler() as profile_step:
901907
sample_state = None
902908
iter_start_time = time.time()
@@ -992,6 +998,8 @@ def _prepare_draft_requests(self):
992998

993999
def _executor_loop_overlap(self):
9941000
torch.cuda.set_device(self.device_id)
1001+
# ensure the context is created, otherwise, some MPI calls will fail.
1002+
CUASSERT(cudart.cudaSetDevice(self.device_id))
9951003
if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
9961004
while self.executor_request_queue.get_request_queue_size(
9971005
) < self.benchmark_req_queues_size:

0 commit comments

Comments (0)