8 changes: 3 additions & 5 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -110,13 +110,11 @@ def __init__(
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition

max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
if self.use_dp:
max_num_tokens *= model_config.mapping.world_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
if self.moe_max_num_tokens < moe_max_num_tokens:
self.aux_stream = aux_stream_dict[
AuxStreamType.
MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
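The replaced if/else branch and the new one-liner agree under the usual mapping semantics. A minimal sketch, assuming mapping.dp_size equals world_size when attention DP is enabled and 1 otherwise (the helper name is hypothetical):

def resolve_moe_max_num_tokens(max_num_tokens, dp_size, moe_max_num_tokens=None):
    # Attention DP replicates tokens across ranks before the MoE, so the MoE
    # may see up to max_num_tokens from each of the dp_size ranks.
    default = max_num_tokens * dp_size
    return moe_max_num_tokens or default

assert resolve_moe_max_num_tokens(8192, dp_size=4) == 32768   # attention DP across 4 ranks
assert resolve_moe_max_num_tokens(8192, dp_size=1) == 8192    # no attention DP
assert resolve_moe_max_num_tokens(8192, 4, 16384) == 16384    # explicit override wins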
260 changes: 241 additions & 19 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -11,7 +11,7 @@

from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import AuxStreamType, Fp4QuantizedTensor
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8(

def masked_index_copy_group_quant_fp8(
output: torch.Tensor,
output_s: torch.Tensor,
input: torch.Tensor,
start_offsets: torch.Tensor,
row_indices: torch.Tensor,
@@ -108,14 +108,10 @@ def masked_index_copy_group_quant_fp8(
col_size = output.shape[1]
dim_size = output.shape[2]

# create padded output_s
alignment = 4
scale_dim = (dim_size + group_size - 1) // group_size
padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
padded_col_size = (col_size + alignment - 1) // alignment * alignment
output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
dtype=torch.int32,
device='cuda')

# get block/grid/stage/warp
num_groups = (dim_size + group_size - 1) // group_size
@@ -247,17 +244,14 @@ def preprocess_after_permute(expert_first_token_offset_tensor,

@nvtx_range("[DG]")
def deepgemm_fp8_group_blockwise_gemm(
d: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
sfa: torch.Tensor,
sfb: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
) -> torch.Tensor:
d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
device=b.device,
dtype=torch.bfloat16)

# NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
assert a.stride(-1) == 1
assert b.stride(-1) == 1
@@ -287,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm(
masked_m,
expected_m,
disable_ue8m0_cast=True)
return d
return


def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
workspace = workspace[0:g * m * k]
workspace = workspace.as_strided(
size=(g, m, k),
stride=(m * k, k, 1),
)
return workspace
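A quick sanity check of set_strides: it reinterprets the first g*m*k elements of a flat, preallocated workspace as a contiguous (g, m, k) view without copying. A minimal sketch (CPU tensors for illustration):

import torch

workspace = torch.empty(2 * 128 * 64, dtype=torch.bfloat16)
view = set_strides(workspace, g=2, m=128, k=64)
assert view.shape == (2, 128, 64)
assert view.data_ptr() == workspace.data_ptr()  # same storage, no new allocation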


class DeepGemmFusedMoE(CutlassFusedMoE):
@@ -327,6 +330,18 @@ def __init__(
apply_router_weight_on_input: bool = False,
layer_idx: Optional[int] = None,
):
if model_config.moe_max_num_tokens is None:
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
# The default moe_max_num_tokens is calculated from the following formula:
# max_isl = 8192, max_batch_size = 1024, mtp = 0
# max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344
# moe_max_num_tokens = max_num_tokens * 2 = 18688
# This cap avoids OOM for the 8k-ISL/1k-batch case.
default_moe_max_num_tokens = 18688
if moe_max_num_tokens > default_moe_max_num_tokens:
model_config._frozen = False
model_config.moe_max_num_tokens = default_moe_max_num_tokens
model_config._frozen = True

super().__init__(
routing_method=routing_method,
@@ -342,6 +357,37 @@ def __init__(
layer_idx=layer_idx,
)
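
The 18688 default above follows directly from the constants in the comment; a worked sketch of the arithmetic (with max_isl = 8192 for the 8k case):

max_isl, max_batch_size, mtp = 8192, 1024, 0
max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64
assert max_num_tokens == 9344
assert max_num_tokens * 2 == 18688  # default_moe_max_num_tokens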

def get_workspace(self, m_max: int, group_size: int):
hidden_size = self.hidden_size
intermediate_size = self.intermediate_size
num_experts = self.expert_size_per_partition

# create workspace
fp8_dim = max(hidden_size, intermediate_size)
workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
dtype=torch.float8_e4m3fn,
device='cuda')
workspace_1 = torch.empty(
(num_experts * m_max * max(intermediate_size * 2, hidden_size)),
dtype=torch.bfloat16,
device='cuda')

# create workspace for scaling factors
m_padded = fp8_utils.align(m_max, 4)
scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
scale_k_padded = fp8_utils.align(scale_k, 4)
workspace_sf = torch.empty(
(num_experts * (scale_k_padded // 4) * m_padded),
dtype=torch.int32,
device='cuda')

workspace = {
"workspace_0": workspace_0,
"workspace_1": workspace_1,
"workspace_sf": workspace_sf,
}
return workspace
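
To make the footprint concrete, a sketch with illustrative sizes (hidden_size=7168, intermediate_size=2048, 8 local experts, and m_max=512 are assumptions, not values from this PR); ceil_div and align mirror the fp8_utils helpers used above:

def ceil_div(a, b):
    return (a + b - 1) // b

def align(a, b):
    return ceil_div(a, b) * b

hidden_size, intermediate_size, num_experts = 7168, 2048, 8
m_max, group_size = 512, 128

fp8_dim = max(hidden_size, intermediate_size)                        # 7168
ws0 = num_experts * m_max * fp8_dim                                  # fp8 elements, 1 B each
ws1 = num_experts * m_max * max(intermediate_size * 2, hidden_size)  # bf16 elements, 2 B each
scale_k_padded = align(ceil_div(fp8_dim, group_size), 4)             # 56
ws_sf = num_experts * (scale_k_padded // 4) * align(m_max, 4)        # int32 words, 4 B each
print(ws0, ws1 * 2, ws_sf * 4)  # ~28 MiB + ~56 MiB + ~224 KiB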

def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
@@ -362,6 +408,7 @@ def forward_chunk(
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
workspace: Optional[dict] = None,
) -> torch.Tensor:
if isinstance(x, Fp4QuantizedTensor):
assert output_dtype is not None
@@ -437,32 +484,72 @@ def forward_chunk(
masked_m, token_to_expert_map = preprocess_after_permute(
expert_first_token_offset_tensor, permuted_data_tensor)

m_max = (x.shape[0] + 127) // 128 * 128
expected_m = (token_selected_experts.numel() +
self.expert_size_per_partition -
1) // self.expert_size_per_partition
act_input_fp8 = torch.empty(
(self.expert_size_per_partition, m_max, self.hidden_size),
dtype=torch.float8_e4m3fn,
device='cuda')

# padding and quantization
m_max = fp8_utils.align(x.shape[0], 128)
act_input_fp8 = set_strides(workspace["workspace_0"],
self.expert_size_per_partition, m_max,
self.hidden_size)

m_padded = fp8_utils.align(m_max, 4)
scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
scale_k_padded = fp8_utils.align(scale_k, 4)
act_input_sf = set_strides(workspace["workspace_sf"],
self.expert_size_per_partition,
scale_k_padded // 4, m_padded)

act_input_sf = masked_index_copy_group_quant_fp8(
act_input_fp8,
act_input_sf,
permuted_data_tensor,
expert_first_token_offset_tensor,
token_to_expert_map,
group_size=128)

h1 = deepgemm_fp8_group_blockwise_gemm(
# grouped gemm 1
h1 = set_strides(workspace["workspace_1"],
self.expert_size_per_partition, m_max,
self.intermediate_size * 2)

deepgemm_fp8_group_blockwise_gemm(
d=h1,
a=act_input_fp8,
b=self.w3_w1_weight,
sfa=act_input_sf,
sfb=self.quant_scales[0],
masked_m=masked_m,
expected_m=expected_m,
)
act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
h3 = deepgemm_fp8_group_blockwise_gemm(

# activation and quantization
act_input_fp8 = set_strides(workspace["workspace_0"],
self.expert_size_per_partition, m_max,
self.intermediate_size)

scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
scale_k_padded = fp8_utils.align(scale_k, 4)
act_input_sf = set_strides(workspace["workspace_sf"],
self.expert_size_per_partition,
scale_k_padded // 4, m_padded)

act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
output=act_input_fp8,
output_scale=act_input_sf,
input=h1,
quant_group_size=128,
masked_m=masked_m,
scale_ue8m0=True)

# grouped gemm 2
h3 = set_strides(workspace["workspace_1"],
self.expert_size_per_partition, m_max,
self.hidden_size)

deepgemm_fp8_group_blockwise_gemm(
d=h3,
a=act_input_fp8,
b=self.w2_weight,
sfa=act_input_sf,
@@ -471,6 +558,7 @@ def forward_chunk(
expected_m=expected_m,
)

# gather and finalize
triton_masked_index_gather(permuted_data_tensor, h3,
expert_first_token_offset_tensor,
token_to_expert_map)
@@ -495,3 +583,137 @@ def forward_chunk(
)

return final_hidden_states
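
The act_input_sf views above use a packed layout of shape (num_experts, scale_k_padded // 4, m_padded) in int32. A sketch of that packing, assuming each per-group scale is a single UE8M0 exponent byte so four consecutive groups share one little-endian int32 word (consistent with the // 4 packed dimension and scale_ue8m0=True above):

import torch

num_experts, m_padded, scale_k_padded = 8, 512, 56
sf = torch.zeros(num_experts, scale_k_padded // 4, m_padded, dtype=torch.int32)

def pack4(b0, b1, b2, b3):
    # Four UE8M0 exponent bytes in one little-endian int32 word.
    return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)

# Illustrative exponent bytes for groups 0..3 of token 0 in expert 0.
sf[0, 0, 0] = pack4(127, 128, 126, 127)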

def forward(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
do_finalize: bool = True, # used by other MoE backends
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
if self.use_dp and self.parallel_size > 1:
assert all_rank_num_tokens is not None
assert use_dp_padding is not None
num_rows = sum(all_rank_num_tokens)
else:
num_rows = x.shape[0]

# If num_rows exceeds moe_max_num_tokens * 2, split the input into multiple
# chunks, because chunked MoE uses two streams and preallocates two workspaces.
num_chunks = 1
if num_rows > self.moe_max_num_tokens * 2:
num_chunks = (num_rows + self.moe_max_num_tokens -
1) // self.moe_max_num_tokens

if use_dp_padding:
all_rank_num_tokens_padded = [all_rank_max_num_tokens
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens

if num_chunks == 1:
# create workspace
num_rows = x.shape[0]
if self.use_dp:
num_rows = sum(all_rank_num_tokens_padded)
m_max = fp8_utils.align(num_rows, 128)
workspace = self.get_workspace(m_max, 128)
outputs = self.forward_chunk(
x,
router_logits,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding,
workspace=workspace)
outputs = self.reducescatter_or_allreduce(
outputs,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding)
else:
if self.use_dp:
all_rank_chunk_size_list = [
self.split_chunk(val, num_chunks)
for val in all_rank_num_tokens_padded
]
all_rank_num_tokens_list = [[
val[idx_chunk] for val in all_rank_chunk_size_list
] for idx_chunk in range(num_chunks)]
chunk_size_list = all_rank_chunk_size_list[self.rank]
else:
all_rank_num_tokens_list = [None] * num_chunks
chunk_size_list = self.split_chunk(x.shape[0], num_chunks)

# create workspace
chunk_size_0 = sum(all_rank_num_tokens_list[0]
) if self.use_dp else chunk_size_list[0]
chunk_size_1 = sum(all_rank_num_tokens_list[1]
) if self.use_dp else chunk_size_list[1]
workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
128)
workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
128)

x_list = x.split(chunk_size_list)
router_logits_list = router_logits.split(chunk_size_list)

self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()

def _forward_chunk(x_, router_logits_, idx, workspace):
return self.forward_chunk(
x_,
router_logits_,
all_rank_num_tokens=all_rank_num_tokens_list[idx]
if self.use_dp else None,
use_dp_padding=use_dp_padding,
workspace=workspace)

def _reducescatter_or_allreduce(x_, idx):
return self.reducescatter_or_allreduce(
x_,
all_rank_num_tokens=all_rank_num_tokens_list[idx],
use_dp_padding=use_dp_padding)

outputs_list = []
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
for idx_chunk, (x, router_logits) in enumerate(
zip(x_list, router_logits_list)):

if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
outputs = _forward_chunk(x, router_logits, idx_chunk,
workspace_0)
if idx_chunk > 0:
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], idx_chunk - 1)
else:
outputs = _forward_chunk(x, router_logits, idx_chunk,
workspace_1)
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], idx_chunk - 1)

outputs_list.append(outputs)

if num_chunks % 2 == 0:
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], -1)
else:
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], -1)
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeChunkingOverlap].record()
self.event_dict[EventType.MoeChunkingOverlap].wait()

outputs = torch.cat(outputs_list)

if self.use_dp and self.parallel_size > 1:
rank = self.mapping.tp_rank
outputs = outputs[:all_rank_num_tokens[rank]]
return outputs
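
A torch-free sketch of the two-stream ping-pong in the multi-chunk path above: even chunks compute on the aux stream with workspace_0 while the previous chunk's reduce-scatter/all-reduce is issued on the main stream, and odd chunks do the opposite, so communication of chunk i-1 overlaps compute of chunk i (run_chunks and its arguments are hypothetical stand-ins):

from contextlib import nullcontext

def run_chunks(chunks, forward_chunk, reduce_op, aux_stream_ctx):
    main_ctx = nullcontext()  # stands in for the default CUDA stream
    outputs = []
    for i, chunk in enumerate(chunks):
        compute_ctx, reduce_ctx = ((aux_stream_ctx, main_ctx) if i % 2 == 0
                                   else (main_ctx, aux_stream_ctx))
        with compute_ctx:
            out = forward_chunk(chunk, workspace_id=i % 2)  # two reusable workspaces
        if i > 0:
            with reduce_ctx:  # overlap comm of chunk i-1 with compute of chunk i
                outputs[-1] = reduce_op(outputs[-1])
        outputs.append(out)
    # The last chunk's reduce follows the same even/odd stream rule as forward().
    last_ctx = main_ctx if len(chunks) % 2 == 0 else aux_stream_ctx
    with last_ctx:
        outputs[-1] = reduce_op(outputs[-1])
    return outputs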
8 changes: 2 additions & 6 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
@@ -81,13 +81,9 @@ def __init__(
self.num_experts)
self.expert_size_per_partition = self.expert_end - self.expert_start

max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
if self.use_dp:
max_num_tokens *= model_config.mapping.world_size
self.moe_max_num_tokens = (model_config.moe_max_num_tokens
if model_config.moe_max_num_tokens
is not None else max_num_tokens)
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens

self._weights_created = False
if not model_config.skip_create_weights_in_init: