
Commit 4c4141a

nv-yilinf authored and dominicshanshan committed
[https://nvbugs/5412562][feat] Allocate MoE workspace only when necessary (release/1.0 retargeted) (NVIDIA#6955)
Signed-off-by: Yilin Fan <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent d5875f4 commit 4c4141a


cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 28 additions & 13 deletions
@@ -389,8 +389,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));

-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);

         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
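
The `stream` argument now threaded into getWorkspaceInfo is not introduced in this hunk; in the PyTorch custom-op paths of moeOp.cpp it would normally be the current ATen CUDA stream for the input's device. A minimal sketch of how such a handle is typically obtained (the include, placement, and use of input.get_device() are assumptions for illustration, not part of this patch):

    #include <ATen/cuda/CUDAContext.h>

    // Assumed: the op receives CUDA tensors, so the cudaStream_t passed to
    // getWorkspaceInfo is the current ATen stream on the input's device.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()).stream();
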
@@ -547,8 +547,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());

-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);

         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);

@@ -702,6 +702,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;

     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
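
The new workspace_info member relies on WorkspaceInfo being default-constructible, so the cached tensor starts out empty (numel() == 0) and the first call triggers an allocation. The struct's definition is outside this diff; based only on the fields used here, a plausible shape looks roughly like the sketch below (field types are inferred, not confirmed by the patch):

    // Inferred from usage in moeOp.cpp; the real definition may differ.
    struct WorkspaceInfo
    {
        torch::Tensor workspace;         // int8 CUDA buffer; numel() doubles as the cached capacity in bytes
        void* src_to_dest_map = nullptr; // sub-pointer carved out of `workspace` via common::nextWorkspacePtr
    };
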
@@ -750,9 +751,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }

-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -761,15 +762,29 @@ class FusedMoeRunner : public torch::CustomClassHolder

         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};

-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());

-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);

-        return info;
+        return workspace_info;
     }

     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
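
tensorrt_llm::common::isCapturing(stream) is used above but defined elsewhere in the tree. A helper with that behavior can be built on the CUDA runtime's stream-capture query; a minimal sketch, assuming only that the helper reports whether the given stream is currently capturing a CUDA graph (the real implementation may differ, e.g. in error handling):

    #include <cuda_runtime_api.h>

    // Sketch: returns true while `stream` is actively capturing a CUDA graph.
    inline bool isCapturingSketch(cudaStream_t stream)
    {
        cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
        cudaError_t const err = cudaStreamIsCapturing(stream, &status);
        return err == cudaSuccess && status == cudaStreamCaptureStatusActive;
    }

The capture check matters for the reason the in-code comment gives: the cached member tensor can be freed or resized by a later, larger request, so a graph captured against that earlier buffer could replay with stale device pointers. Allocating a fresh workspace inside the capture avoids that, presumably because allocations made during capture come from the graph's memory pool and therefore remain valid across replays.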
