@@ -389,8 +389,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));

-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);

         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
@@ -547,8 +547,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());

-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);

         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
@@ -702,6 +702,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;

     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
@@ -750,9 +751,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }

-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /*use_lora*/ false, mUseDeepSeekFP8BlockScaling,
@@ -761,15 +762,29 @@ class FusedMoeRunner : public torch::CustomClassHolder

         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};

-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());

-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);

-        return info;
+        return workspace_info;
     }

     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
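
For context on the change: the workspace tensor now lives in the new workspace_info member, getWorkspaceInfo returns a const reference to it, and the buffer is reallocated only when it is too small, except while a CUDA graph is being captured, where it is always reallocated so the pointer recorded in the graph is not invalidated by a later resize. Below is a minimal sketch of that caching pattern, assuming libtorch and the CUDA runtime; MoeWorkspaceCache, getWorkspace, and required_bytes are illustrative names only, and the raw cudaStreamIsCapturing call stands in for tensorrt_llm::common::isCapturing.

#include <torch/torch.h>

#include <cuda_runtime_api.h>

class MoeWorkspaceCache
{
public:
    // Returns a reference to a cached int8 CUDA tensor of at least required_bytes bytes.
    // The tensor is reallocated when the cache is too small, or unconditionally while
    // `stream` is capturing a CUDA graph (per the PR comment, this avoids an illegal
    // memory access when the graph is replayed after a later resize freed the old buffer).
    torch::Tensor const& getWorkspace(int64_t required_bytes, cudaStream_t stream)
    {
        cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
        cudaStreamIsCapturing(stream, &status);
        bool const is_capturing = status != cudaStreamCaptureStatusNone;

        if (is_capturing || !mWorkspace.defined() || mWorkspace.numel() < required_bytes)
        {
            mWorkspace = torch::empty(
                {required_bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
        }
        return mWorkspace;
    }

private:
    torch::Tensor mWorkspace; // persists across calls, like the new workspace_info member in the diff
};

Because getWorkspaceInfo now returns a reference to member state, callers should not hold the reference across another call that may reallocate the underlying tensor.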