Commit 2af5a71

add deepep
1 parent 8f68177 commit 2af5a71

9 files changed: +838, -13 lines

lmdeploy/cli/serve.py

Lines changed: 2 additions & 0 deletions

@@ -165,6 +165,7 @@ def add_parser_api_server():
     quant_policy = ArgumentHelper.quant_policy(pt_group)
     ArgumentHelper.dp(pt_group)
     ArgumentHelper.dp_rank(pt_group)
+    ArgumentHelper.ep(pt_group)

     # turbomind args
     tb_group = parser.add_argument_group('TurboMind engine arguments')
@@ -296,6 +297,7 @@ def api_server(args):
         tp=args.tp,
         dp=args.dp,
         dp_rank=args.dp_rank,
+        ep=args.ep,
         max_batch_size=max_batch_size,
         cache_max_entry_count=args.cache_max_entry_count,
         block_size=args.cache_block_seq_len,
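
These two hunks wire the new flag end to end: the PyTorch-engine argument group gains --ep, and api_server() forwards args.ep into the engine config next to tp/dp/dp_rank. A minimal sketch of that flow with a bare argparse parser rather than LMDeploy's full CLI, against this commit's PytorchEngineConfig (the sample value 2 is illustrative only):

    import argparse

    from lmdeploy import PytorchEngineConfig

    parser = argparse.ArgumentParser()
    pt_group = parser.add_argument_group('PyTorch engine arguments')
    # Equivalent of the ArgumentHelper.ep(pt_group) helper added in lmdeploy/cli/utils.py.
    pt_group.add_argument('--ep', type=int, default=1, help='expert parallelism. Should be 2^n.')

    args = parser.parse_args(['--ep', '2'])
    # api_server() now passes args.ep through to the backend config.
    backend_config = PytorchEngineConfig(tp=1, dp=1, dp_rank=0, ep=args.ep)
    print(backend_config.ep)  # 2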

lmdeploy/cli/utils.py

Lines changed: 6 additions & 0 deletions

@@ -157,6 +157,12 @@ def dp(parser):
         """Add argument dp to parser."""

         return parser.add_argument('--dp', type=int, default=1, help='data parallelism. dp_rank is required.')
+
+    @staticmethod
+    def ep(parser):
+        """Add argument ep to parser."""
+
+        return parser.add_argument('--ep', type=int, default=1, help='expert parallelism. Should be 2^n.')

     @staticmethod
     def dp_rank(parser):
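
The help text says the value "Should be 2^n", but nothing in this commit enforces that; the config check added in lmdeploy/messages.py only requires ep >= 1. If a caller wanted to validate the hint themselves, a standard bit-trick check would look like this (an illustration, not part of the commit):

    def is_power_of_two(ep: int) -> bool:
        # A positive integer is a power of two iff it has exactly one set bit.
        return ep >= 1 and (ep & (ep - 1)) == 0

    print([n for n in range(1, 9) if is_power_of_two(n)])  # [1, 2, 4, 8]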

lmdeploy/messages.py

Lines changed: 2 additions & 0 deletions

@@ -294,6 +294,7 @@ class PytorchEngineConfig:
     tp: int = 1
     dp: int = 1
     dp_rank: int = 0
+    ep: int = 1
     session_len: int = None
     max_batch_size: int = None
     cache_max_entry_count: float = 0.8
@@ -318,6 +319,7 @@ def __post_init__(self):
         assert self.dtype in ['auto', 'float16', 'bfloat16']
         assert self.tp >= 1, 'invalid tp'
         assert self.dp >= 1, 'invalid dp'
+        assert self.ep >= 1, 'invalid ep'
         assert 0 < self.cache_max_entry_count < 1, \
             'invalid cache_max_entry_count'
         assert self.num_cpu_blocks >= 0, 'invalid num_cpu_blocks'
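
Alongside the new dataclass field, __post_init__ gains the same sanity check as tp and dp. A small sketch against this commit's config (the rejected value is illustrative):

    from lmdeploy import PytorchEngineConfig

    # ep defaults to 1, so existing configs are unaffected.
    print(PytorchEngineConfig().ep)  # 1

    # Non-positive values are rejected by the new assert in __post_init__.
    try:
        PytorchEngineConfig(ep=0)
    except AssertionError as err:
        print(err)  # invalid ep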

lmdeploy/pytorch/backends/cuda/moe.py

Lines changed: 180 additions & 6 deletions

@@ -3,10 +3,14 @@
 from typing import List

 import torch
-
+import torch.distributed as dist
+from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPDispatcher
+from lmdeploy.pytorch.distributed import get_dist_manager
 from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
 from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
 from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
+from lmdeploy.pytorch.kernels.cuda.ep_moe import grouped_gemm_triton, silu_and_mul_triton_kernel
 from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8
 from lmdeploy.pytorch.models.q_modules import QTensor

@@ -227,18 +231,188 @@ def forward(self,
         return output


+
+class DeepEPMoE:
+    """
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        ep_size: int,
+        block_shape: list[int],
+    ):
+        self.num_experts = num_experts
+        self.ep_size = ep_size
+        assert self.num_experts % self.ep_size == 0
+        self.num_experts_per_partition = self.num_experts // self.ep_size
+        self.block_shape = block_shape
+        self.use_fp8_w8a8 = True
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        tokens_per_expert: torch.Tensor,
+        gate_up_weight: torch.Tensor,
+        gate_up_scale: torch.Tensor,
+        gate_down_weight: torch.Tensor,
+        gate_down_scale: torch.Tensor
+    ):
+        seg_indptr_cur_rank = torch.cat(
+            [
+                torch.zeros(
+                    1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
+                ),
+                torch.cumsum(tokens_per_expert, dim=0),
+            ]
+        )
+        reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
+        weight_indices_cur_rank = torch.arange(
+            0,
+            self.num_experts_per_partition,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            gate_up_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if hidden_states.shape[0] > 0:
+            input, input_scale = quant_fp8(hidden_states, 128, dtype=gate_up_weight.dtype)
+            gateup_output = grouped_gemm_triton(
+                a=input,
+                b=gate_up_weight,
+                c=gateup_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr_cur_rank,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=input_scale,
+                scale_b=gate_up_scale,
+                block_shape=self.block_shape,
+            )
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=hidden_states.dtype,
+        )
+        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+            gateup_output,
+            down_input,
+            gateup_output.shape[1],
+            reorder_topk_ids,
+            None,
+            0,
+            self.num_experts_per_partition - 1,
+            BLOCK_SIZE=512,
+        )
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            gate_down_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            down_input, down_input_scale = quant_fp8(down_input, 128, dtype=gate_down_weight.dtype)
+            down_output = grouped_gemm_triton(
+                a=down_input,
+                b=gate_down_weight,
+                c=down_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr_cur_rank,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=down_input_scale,
+                scale_b=gate_down_scale,
+                block_shape=self.block_shape,
+            )
+        return down_output
+
+
+class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):
+
+    def __init__(self,
+                 ep_size: int,
+                 ep_group: dist.ProcessGroup,
+                 top_k: int,
+                 num_experts: int,
+                 hidden_dim: int,
+                 renormalize: bool = False,
+                 block_size: int = 128,
+                 out_dtype: torch.dtype = torch.bfloat16):
+        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)
+        self.token_dispatcher = DeepEPDispatcher(
+            group=ep_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_experts // ep_size,
+            hidden_size=hidden_dim,
+            params_dtype=out_dtype,
+        )
+        self.experts = DeepEPMoE(num_experts, ep_size, [block_size, block_size])
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                topk_weights: torch.Tensor,
+                topk_ids: torch.LongTensor,
+                gate_up_weights: torch.Tensor,
+                gate_up_scale: torch.Tensor,
+                down_weights: torch.Tensor,
+                down_scale: torch.Tensor,
+                expert_list: List[int] = None):
+        """forward."""
+        topk_weights = _renormalize(topk_weights, self.renormalize)
+        recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = (
+            self.token_dispatcher.dispatch(
+                hidden_states,
+                topk_ids.to(torch.int32),
+                topk_weights.to(torch.float32),
+                self.num_experts,
+            )
+        )
+        out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, gate_up_weights, gate_up_scale,
+                                          down_weights, down_scale)
+        out_states = self.token_dispatcher.combine(out_states)
+        return out_states
+
 class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
     """triton fused moe blocked f8 builder."""

     @staticmethod
     def build(top_k: int,
               num_experts: int,
+              hidden_dim: int,
               renormalize: bool = False,
               block_size: int = 128,
+              ep_size: int = 1,
+              ep_group: dist.ProcessGroup = None,
               out_dtype: torch.dtype = torch.float16):
         """build from mlp."""
-        return TritonFusedMoEBlockedF8Impl(top_k=top_k,
-                                           num_experts=num_experts,
-                                           renormalize=renormalize,
-                                           block_size=block_size,
-                                           out_dtype=out_dtype)
+        if ep_size > 1:
+            return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,
+                                               ep_group=ep_group,
+                                               top_k=top_k,
+                                               num_experts=num_experts,
+                                               hidden_dim=hidden_dim,
+                                               renormalize=renormalize,
+                                               block_size=block_size,
+                                               out_dtype=out_dtype)
+        else:
+            return TritonFusedMoEBlockedF8Impl(top_k=top_k,
+                                               num_experts=num_experts,
+                                               renormalize=renormalize,
+                                               block_size=block_size,
+                                               out_dtype=out_dtype)
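
In DeepEPMoE.forward above, the grouped GEMMs consume dispatched rows grouped by local expert, so the only bookkeeping needed is a cumulative-sum segment pointer and a per-row expert id. A CPU-only illustration with a made-up tokens_per_expert (the values are assumptions, not from the commit):

    import torch

    tokens_per_expert = torch.tensor([3, 0, 2, 1])  # rows routed to each of 4 local experts

    # seg_indptr_cur_rank: segment boundaries of each expert's rows in the dispatched buffer.
    seg_indptr = torch.cat([
        torch.zeros(1, dtype=tokens_per_expert.dtype),
        torch.cumsum(tokens_per_expert, dim=0),
    ])

    # reorder_topk_ids: local expert id of every dispatched row; the single-tensor form of
    # repeat_interleave expands each index by its count.
    reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)

    print(seg_indptr)        # tensor([0, 3, 3, 5, 6])
    print(reorder_topk_ids)  # tensor([0, 0, 0, 2, 2, 3])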

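The activation step between the two grouped GEMMs halves the hidden dimension because the gate and up projections are fused in one weight. A plain-PyTorch reference for what silu_and_mul_triton_kernel computes per row, assuming the usual gate/up split and ignoring the kernel's expert-range masking arguments (a sketch, not the kernel itself):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(gateup_output: torch.Tensor) -> torch.Tensor:
        # First half is the gate projection, second half the up projection.
        d = gateup_output.shape[1] // 2
        gate, up = gateup_output[:, :d], gateup_output[:, d:]
        return F.silu(gate) * up

    x = torch.randn(6, 8)
    print(silu_and_mul_ref(x).shape)  # torch.Size([6, 4])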