 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, CompilationLevel, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor

 from vllm.model_executor.models.interfaces import SupportsPP
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                    Qwen3MoeForCausalLM,
-                                                   Qwen3MoeMLP, Qwen3MoeModel)
+                                                   Qwen3MoeMLP, Qwen3MoeModel,
+                                                   Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
@@ -78,12 +79,21 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
+        use_aclgraph = (vllm_config is not None
+                        and vllm_config.compilation_config.level
+                        == CompilationLevel.PIECEWISE
+                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not use_aclgraph:
+                self.mlp = AscendSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
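With this change, a decoder layer keeps the Ascend-specific `AscendSparseMoeBlock` only when ACL graph capture is off; under piecewise torch.compile without `enforce_eager` it falls back to the upstream `Qwen3MoeSparseMoeBlock`. The added condition can be read as a small predicate; a minimal sketch for illustration (the standalone helper `_use_aclgraph` is not part of the patch):

```python
from typing import Optional

from vllm.config import CompilationLevel, VllmConfig


def _use_aclgraph(vllm_config: Optional[VllmConfig]) -> bool:
    # Same check the diff adds inline: fall back to the upstream sparse MoE
    # block only when piecewise compilation (ACL graph capture) is selected
    # and eager execution has not been forced.
    return (vllm_config is not None
            and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)
```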