 from typing import List
 
 import torch
-
+import torch.distributed as dist
+from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPDispatcher
+from lmdeploy.pytorch.distributed import get_dist_manager
 from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
 from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
 from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
+from lmdeploy.pytorch.kernels.cuda.ep_moe import grouped_gemm_triton, silu_and_mul_triton_kernel
 from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8
 from lmdeploy.pytorch.models.q_modules import QTensor
 
@@ -227,18 +231,188 @@ def forward(self,
         return output
 
 
+
+class DeepEPMoE:
+    """MoE expert parallel implementation based on DeepEP
+    (https://github.com/deepseek-ai/DeepEP/tree/main)."""
+
+    def __init__(
+        self,
+        num_experts: int,
+        ep_size: int,
+        block_shape: List[int],
+    ):
+        self.num_experts = num_experts
+        self.ep_size = ep_size
+        assert self.num_experts % self.ep_size == 0
+        self.num_experts_per_partition = self.num_experts // self.ep_size
+        self.block_shape = block_shape
+        self.use_fp8_w8a8 = True
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        tokens_per_expert: torch.Tensor,
+        gate_up_weight: torch.Tensor,
+        gate_up_scale: torch.Tensor,
+        gate_down_weight: torch.Tensor,
+        gate_down_scale: torch.Tensor,
+    ):
+        # segment pointers: tokens [seg_indptr[e], seg_indptr[e + 1]) belong to
+        # local expert e after dispatch.
+        seg_indptr_cur_rank = torch.cat([
+            torch.zeros(1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype),
+            torch.cumsum(tokens_per_expert, dim=0),
+        ])
+        # expert id of every received token, repeated by its token count.
+        reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
+        weight_indices_cur_rank = torch.arange(
+            0,
+            self.num_experts_per_partition,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+
+        # GroupGemm-0: gate/up projection.
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            gate_up_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if hidden_states.shape[0] > 0:
+            input, input_scale = quant_fp8(hidden_states, 128, dtype=gate_up_weight.dtype)
+            gateup_output = grouped_gemm_triton(
+                a=input,
+                b=gate_up_weight,
+                c=gateup_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr_cur_rank,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=input_scale,
+                scale_b=gate_up_scale,
+                block_shape=self.block_shape,
+            )
+
+        # Act: SiLU-gated multiply on the two halves of the gate/up output.
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=hidden_states.dtype,
+        )
+        silu_and_mul_triton_kernel[(gateup_output.shape[0], )](
+            gateup_output,
+            down_input,
+            gateup_output.shape[1],
+            reorder_topk_ids,
+            None,
+            0,
+            self.num_experts_per_partition - 1,
+            BLOCK_SIZE=512,
+        )
+
+        # GroupGemm-1: down projection.
+        down_output = torch.empty(
+            down_input.shape[0],
+            gate_down_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            down_input, down_input_scale = quant_fp8(down_input, 128, dtype=gate_down_weight.dtype)
+            down_output = grouped_gemm_triton(
+                a=down_input,
+                b=gate_down_weight,
+                c=down_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr_cur_rank,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=down_input_scale,
+                scale_b=gate_down_scale,
+                block_shape=self.block_shape,
+            )
+        return down_output
+
+
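To make the index bookkeeping at the top of `DeepEPMoE.forward` concrete, here is a small self-contained sketch (not part of the diff) showing what `seg_indptr_cur_rank` and `reorder_topk_ids` evaluate to for a toy `tokens_per_expert`; the values are purely illustrative.

import torch

# Suppose this rank owns 4 local experts and, after dispatch, received
# 3, 0, 2 and 1 tokens for them respectively (illustrative values).
tokens_per_expert = torch.tensor([3, 0, 2, 1])

# Same construction as in DeepEPMoE.forward: prepend a zero to the cumulative
# sum so that tokens [seg_indptr[e], seg_indptr[e + 1]) belong to expert e.
seg_indptr = torch.cat([
    torch.zeros(1, dtype=tokens_per_expert.dtype),
    torch.cumsum(tokens_per_expert, dim=0),
])
print(seg_indptr)        # tensor([0, 3, 3, 5, 6])

# repeat_interleave on a 1-D count tensor expands to one expert id per token.
reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
print(reorder_topk_ids)  # tensor([0, 0, 0, 2, 2, 3])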
+class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):
+    """DeepEP-based fused MoE blocked fp8 implementation."""
+
+    def __init__(self,
+                 ep_size: int,
+                 ep_group: dist.ProcessGroup,
+                 top_k: int,
+                 num_experts: int,
+                 hidden_dim: int,
+                 renormalize: bool = False,
+                 block_size: int = 128,
+                 out_dtype: torch.dtype = torch.bfloat16):
+        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)
+        self.token_dispatcher = DeepEPDispatcher(
+            group=ep_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_experts // ep_size,
+            hidden_size=hidden_dim,
+            params_dtype=out_dtype,
+        )
+        self.experts = DeepEPMoE(num_experts, ep_size, [block_size, block_size])
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                topk_weights: torch.Tensor,
+                topk_ids: torch.LongTensor,
+                gate_up_weights: torch.Tensor,
+                gate_up_scale: torch.Tensor,
+                down_weights: torch.Tensor,
+                down_scale: torch.Tensor,
+                expert_list: List[int] = None):
+        """forward."""
+        topk_weights = _renormalize(topk_weights, self.renormalize)
+        # dispatch tokens to the ranks that own their selected experts.
+        recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = (
+            self.token_dispatcher.dispatch(
+                hidden_states,
+                topk_ids.to(torch.int32),
+                topk_weights.to(torch.float32),
+                self.num_experts,
+            ))
+        # run the local experts, then gather the outputs back to the source ranks.
+        out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, gate_up_weights, gate_up_scale,
+                                          down_weights, down_scale)
+        out_states = self.token_dispatcher.combine(out_states)
+        return out_states
+
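`_renormalize` is imported from `lmdeploy.pytorch.kernels.cuda.fused_moe` and its body is not part of this diff. As a hedged stand-in, the sketch below shows the standard top-k renormalization it is presumably performing when `renormalize` is set: scaling each token's routing weights so they sum to 1. The helper name `renormalize_topk_weights` is illustrative only.

import torch

def renormalize_topk_weights(topk_weights: torch.Tensor, renormalize: bool) -> torch.Tensor:
    # Illustrative stand-in, not the actual _renormalize kernel: normalize each
    # token's selected routing weights to sum to 1.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights

w = torch.tensor([[0.5, 0.3], [0.2, 0.2]])
print(renormalize_topk_weights(w, True))
# tensor([[0.6250, 0.3750],
#         [0.5000, 0.5000]])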
 class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
     """triton fused moe blocked f8 builder."""
 
     @staticmethod
     def build(top_k: int,
               num_experts: int,
+              hidden_dim: int,
               renormalize: bool = False,
               block_size: int = 128,
+              ep_size: int = 1,
+              ep_group: dist.ProcessGroup = None,
               out_dtype: torch.dtype = torch.float16):
         """build from mlp."""
-        return TritonFusedMoEBlockedF8Impl(top_k=top_k,
-                                           num_experts=num_experts,
-                                           renormalize=renormalize,
-                                           block_size=block_size,
-                                           out_dtype=out_dtype)
+        if ep_size > 1:
+            return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,
+                                               ep_group=ep_group,
+                                               top_k=top_k,
+                                               num_experts=num_experts,
+                                               hidden_dim=hidden_dim,
+                                               renormalize=renormalize,
+                                               block_size=block_size,
+                                               out_dtype=out_dtype)
+        else:
+            return TritonFusedMoEBlockedF8Impl(top_k=top_k,
+                                               num_experts=num_experts,
+                                               renormalize=renormalize,
+                                               block_size=block_size,
+                                               out_dtype=out_dtype)
+
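For readers who want to sanity-check the data flow of `DeepEPMoE.forward`, here is a plain-PyTorch sketch of the same per-expert pipeline without fp8 quantization, DeepEP dispatch, or the Triton grouped GEMM: slice each expert's token segment, apply the gate/up projection, a SiLU-gated multiply, then the down projection. The weight layouts and the `silu(gate) * up` split are assumptions made for illustration, not taken from the kernel code.

import torch
import torch.nn.functional as F

def dense_reference_moe(hidden_states, tokens_per_expert, gate_up_weight, down_weight):
    # Per-expert reference: GroupGemm-0 -> silu_and_mul -> GroupGemm-1.
    # Assumed layouts (illustrative): gate_up_weight is
    # [num_local_experts, 2 * inter_dim, hidden_dim], down_weight is
    # [num_local_experts, hidden_dim, inter_dim]; tokens arrive sorted by expert.
    seg_indptr = torch.cat([
        torch.zeros(1, dtype=tokens_per_expert.dtype),
        torch.cumsum(tokens_per_expert, dim=0),
    ])
    out = torch.empty(hidden_states.shape[0], down_weight.shape[1],
                      dtype=hidden_states.dtype)
    for e in range(tokens_per_expert.numel()):
        start, end = seg_indptr[e].item(), seg_indptr[e + 1].item()
        if start == end:
            continue
        x = hidden_states[start:end]
        gate_up = x @ gate_up_weight[e].t()          # GroupGemm-0
        gate, up = gate_up.chunk(2, dim=-1)
        act = F.silu(gate) * up                      # silu_and_mul
        out[start:end] = act @ down_weight[e].t()    # GroupGemm-1
    return out

# Tiny smoke test with random weights (shapes chosen for illustration only).
E, H, I_ = 2, 8, 16
hidden = torch.randn(5, H)
counts = torch.tensor([3, 2])
w_gate_up = torch.randn(E, 2 * I_, H)
w_down = torch.randn(E, H, I_)
print(dense_reference_moe(hidden, counts, w_gate_up, w_down).shape)  # torch.Size([5, 8])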