Skip to content

Commit df7b9cd

Browse files
committed
add manual bucketing pass
1 parent 8659543 commit df7b9cd

File tree

6 files changed

+73
-27
lines changed

6 files changed

+73
-27
lines changed

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
5151

5252
2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
5353
- "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.
54-
55-
56-
users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:
57-
58-
```bash
59-
--compile.model_backend_override "aot_eager_autobucketing"
60-
```
54+
```bash
55+
--compile.backend "aot_eager" --compile.model_backend_override "aot_eager_autobucketing"
56+
```
57+
58+
3. manual optimization: perform manual bucketing & reordering with user FQN inputs.
59+
- "aot_eager_manualbucketing": perform manual bucketing at aten fx-level, and perform code execution with aot_eager backend.
60+
```bash
61+
--compile.backend "aot_eager" --compile.model_backend_override "aot_eager_manualbucketing" --compile.fsdp_manual_bucketed_modules "tok_embeddings,layers.[0-5],norm+output"
62+
```
6163

6264
### Citation
6365

torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@
88

99
import torch
1010

11+
from .job_config import Compile as CompileConfig
1112

12-
def get_compile_backend(backend_name: str) -> Union[str, callable]:
13+
14+
def get_compile_backend(compile_config: CompileConfig) -> Union[str, callable]:
1315
# return the compile backends used in SimpleFSDP training
1416
# Step1: check if backend_name is inside available torch.compile backends
1517
# Step2: check if the backend_name has been registered as a customized backend
18+
backend_name = (
19+
getattr(compile_config, "model_backend_override", None)
20+
or compile_config.backend
21+
)
22+
1623
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
1724
if backend_name in available_torch_backend:
1825
return backend_name
@@ -43,6 +50,33 @@ def aten_autobucketing_reordering_pass(
4350
bw_compiler=aten_autobucketing_reordering_pass,
4451
keep_inference_input_mutations=True,
4552
)
53+
elif backend_name == "aot_eager_manualbucketing":
54+
# Perform manual optimization in aten fx-level and execute code in aot_eager backend
55+
# The manualbucketing logic is here:
56+
from functools import partial
57+
58+
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
59+
from torch._inductor.fx_passes.overlap_manual_scheduling import (
60+
manual_overlap_bucketing,
61+
)
62+
63+
torch._inductor.config.allow_buffer_reuse = False
64+
manual_overlap_bucketing = partial(
65+
manual_overlap_bucketing,
66+
module_bucket_plans=compile_config.fsdp_manual_bucketed_modules,
67+
)
68+
69+
def aten_manualbucketing_reordering_pass(
70+
gm: torch.fx.GraphModule, example_inputs: Any
71+
) -> torch.fx.GraphModule:
72+
manual_overlap_bucketing(gm)
73+
return gm
74+
75+
backend = aot_autograd_backend(
76+
fw_compiler=aten_manualbucketing_reordering_pass,
77+
bw_compiler=aten_manualbucketing_reordering_pass,
78+
keep_inference_input_mutations=True,
79+
)
4680
else:
4781
raise AssertionError(f"Unsupported customized backend: {backend_name}")
4882

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def parallelize_deepseekv3(
178178
if job_config.compile.enable:
179179
torch._inductor.config.reorder_for_peak_memory = False
180180
torch._dynamo.config.capture_scalar_outputs = True
181-
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
181+
model = torch.compile(
182+
model, backend=get_compile_backend(job_config.compile), fullgraph=True
183+
)
182184

183185
return model

torchtitan/experiments/simple_fsdp/job_config.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,18 @@
1010
@dataclass
1111
class Compile:
1212
model_backend_override: str | None = None
13-
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
13+
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """
14+
15+
fsdp_manual_bucketed_modules: list[str] = field(default_factory=list)
16+
"""
17+
Manual bucket modules based on user specified FQNs
18+
Abbreviations are supported to make specifying modules easier.
19+
Currently, the following abbreviations are available:
20+
(1) layers.[0-2] -> [layers.0], [layers.1], [layers.2]
21+
(layers are split three separate buckets)
22+
(2) norm+output -> [norm, output]
23+
(norm and output are in one bucket)
24+
"""
1425

1526

1627
@dataclass

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,9 @@ def parallelize_llama(
140140

141141
if job_config.compile.enable and "model" in job_config.compile.components:
142142
torch._inductor.config.reorder_for_peak_memory = False
143-
backend = (
144-
getattr(job_config.compile, "model_backend_override", None)
145-
or job_config.compile.backend
146-
)
147143
model = torch.compile(
148144
model,
149-
backend=get_compile_backend(backend),
145+
backend=get_compile_backend(job_config.compile),
150146
fullgraph=True,
151147
)
152148

torchtitan/experiments/simple_fsdp/tests/integration_tests.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
2929
"1D",
3030
"1d",
3131
),
32-
OverrideDefinitions(
33-
[
34-
[
35-
"--model.name simple_fsdp.llama3",
36-
"--compile.enable",
37-
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
38-
"--compile.model_backend_override aot_eager_autobucketing",
39-
],
40-
],
41-
"1D+aot_eager_autobucketing",
42-
"1d_aot_eager_autobucketing",
43-
),
32+
# TODO(ruisizhang123): add back after autobucketing pass is mature
33+
# OverrideDefinitions(
34+
# [
35+
# [
36+
# "--model.name simple_fsdp.llama3",
37+
# "--compile.enable",
38+
# "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
39+
# "--compile.model_backend_override aot_eager_autobucketing",
40+
# ],
41+
# ],
42+
# "1D+aot_eager_autobucketing",
43+
# "1d_aot_eager_autobucketing",
44+
# ),
4445
OverrideDefinitions(
4546
[
4647
[

0 commit comments

Comments
 (0)