Commit 8659543
[compiler toolkit] Prepare deepseek to accept graph passes (#1982)
Made some updates to improve UX when running experiments in the compiler toolkit:

- Always register the block mask as a pytree node. A model could use flex_attn even if its flavor name doesn't contain `flex_attn` (see the pytree-registration sketch below).
- Prepare deepseek v3 to accept graph passes, as llama3 already does.
- Annotate flex attention in deepseek v3.
- Regional inductor doesn't work on deepseek with flex attn; it fails with error P2021796847.

To repro the regional inductor issue in dsv3, uncomment `regional_inductor()` and run:

```
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn
```
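For context, registering a type as a pytree node tells PyTorch's tracing machinery how to flatten it into tensor leaves and rebuild it afterwards. Below is a minimal sketch of the pattern, assuming the `torch.utils._pytree` API; the `Pair` class is a hypothetical stand-in, since the real `register_blockmask_pytree_node` handles `torch.nn.attention.flex_attention.BlockMask`:

```python
import torch
import torch.utils._pytree as pytree


class Pair:
    # Hypothetical container standing in for BlockMask, which torchtitan's
    # register_blockmask_pytree_node registers in the same fashion.
    def __init__(self, a: torch.Tensor, b: torch.Tensor):
        self.a, self.b = a, b


pytree.register_pytree_node(
    Pair,
    # flatten: return the tensor children plus static (non-tensor) context
    lambda p: ([p.a, p.b], None),
    # unflatten: rebuild the container from children and context
    lambda children, ctx: Pair(*children),
)
```

Registering unconditionally is the safe default: a model can construct a block mask even when its flavor name doesn't end in `flex_attn`, and an unused registration costs nothing.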
1 parent 2ea6197 commit 8659543

File tree: 3 files changed, +22 −10 lines


torchtitan/experiments/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,4 +30,4 @@ We provide this `experiments/` folder to host experiments that add significant v
 | [torchcomms](./torchcomms/) | TBA | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) |
 | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) |
 | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
-| [compiler_toolkit](./compiler_tookit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
+| [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
```

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 20 additions & 7 deletions
```diff
@@ -30,20 +30,29 @@
 from torchtitan.tools.logging import logger
 
 
-def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
-    logger.info("fwd_gm:")
+def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):
+    logger.info(f"{name} before compiler:")
     logger.info(gm.print_readable(print_output=False))
-    return gm
 
+    # TODO: regional_inductor should work with deepseek_v3
+    # gm = regional_inductor(gm, example_inputs)
 
-def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
-    logger.info("bwd_gm:")
+    logger.info(f"{name} after compiler:")
     logger.info(gm.print_readable(print_output=False))
     return gm
 
 
+def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
+    return compiler("fwd_gm", gm, example_inputs)
+
+
+def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
+    return compiler("bwd_gm", gm, example_inputs)
+
+
 def annotate_deepseekv3() -> None:
     from torchtitan.distributed.expert_parallel import ExpertParallel
+    from torchtitan.models.attention import FlexAttentionWrapper
     from torchtitan.models.moe.moe import MoE
 
     # annotate the MoE with dispatch, compute and combine
@@ -55,6 +64,11 @@ def annotate_deepseekv3() -> None:
     )
     MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)
 
+    # annotate flex_attention with compile_with_inductor
+    FlexAttentionWrapper.forward = annotate_fn(
+        {"compile_with_inductor": "flex_attention"}
+    )(FlexAttentionWrapper.forward)
+
 
 def parallelize_deepseekv3(
     model: torch.nn.Module,
@@ -64,8 +78,7 @@ def parallelize_deepseekv3(
 
     annotate_deepseekv3()
 
-    if job_config.model.flavor.endswith("flex_attn"):
-        register_blockmask_pytree_node()
+    register_blockmask_pytree_node()
 
     # Disable torch.compile over the model in the compiler toolkit style workflow
     with disable_compile(job_config):
```
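The new `compiler()` hook is the seam where graph passes run on the traced `torch.fx.GraphModule`; the commented-out `regional_inductor` call marks the spot. As a hedged sketch of what a pass slotted in there could look like (not part of this commit), here is a hypothetical read-only pass using the standard FX graph API; `log_call_functions` is an illustrative name:

```python
import torch


def log_call_functions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Walk the traced graph and report every call_function node; a real
    # pass (e.g. regional_inductor) would rewrite nodes instead.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(f"{node.name} -> {node.target}")
    # After mutating the graph, a pass must recompile the module so the
    # generated forward() matches the edited graph (a no-op for pure reads).
    gm.recompile()
    return gm
```

A real pass such as `regional_inductor` would rewrite the annotated subgraphs (e.g. the flex_attention regions tagged with `compile_with_inductor`) rather than just logging them.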

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -87,8 +87,7 @@ def parallelize_llama(
 
     annotate_llama()
 
-    if job_config.model.flavor.endswith("flex_attn"):
-        register_blockmask_pytree_node()
+    register_blockmask_pytree_node()
 
     # Disable torch.compile over the model in the compiler toolkit style workflow
     with disable_compile(job_config):
```
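Both models now route tracing through the same `fw_compiler`/`bw_compiler` hooks. For orientation only, one standard way such hooks attach to a backend is via PyTorch's `aot_autograd` helper; this wiring is an assumption for illustration, not necessarily how the compiler toolkit plugs them in:

```python
import torch
from torch._dynamo.backends.common import aot_autograd

# Hypothetical wiring: hand the forward/backward graph hooks to an
# aot_autograd-based backend; the compiler toolkit's actual pipeline
# may differ from this sketch.
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)

model = torch.nn.Linear(8, 8)  # stand-in model
compiled = torch.compile(model, backend=backend)
```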
